Changeset - eeea39f48a29
[Not reviewed]
Cargo.toml
Show inline comments
 
[package]
 
name = "reowolf_rs"
 
version = "1.2.0"
 
authors = [
 
	"Max Henger <henger@cwi.nl>",
 
	"Christopher Esterhuyse <esterhuy@cwi.nl>",
 
	"Hans-Dieter Hiep <hdh@cwi.nl>"
 
]
 
edition = "2021"
 

	
 
[dependencies]
 
# convenience macros
 
maplit = "1.0.2"
 
derive_more = "0.99.2"
 

	
 
# runtime
 
bincode = "1.3.1"
 
serde = { version = "1.0.114", features = ["derive"] }
 
getrandom = "0.1.14" # tiny crate. used to guess controller-id
 

	
 
# network
 
mio = { version = "0.7.0", package = "mio", features = ["udp", "tcp", "os-poll"] }
 
socket2 = { version = "0.3.12", optional = true }
 

	
 
# protocol
 
backtrace = "0.3"
 
lazy_static = "1.4.0"
 

	
 
# ffi
 

	
 
# socket ffi
 
libc = { version = "^0.2", optional = true }
 
os_socketaddr = { version = "0.1.0", optional = true }
 

	
 
[dev-dependencies]
 
# randomness
 
rand = "0.8.4"
 
rand_pcg = "0.3.1"
 

	
 
[lib]
 
crate-type = [
 
	"rlib", # for use as a Rust dependency.
 
]
 
\ No newline at end of file
docs/runtime/sync.md
Show inline comments
 
new file 100644
 
# Synchronous Communication
 

	
 
## 
 
\ No newline at end of file
src/common.rs
Show inline comments
 
deleted file
src/lib.rs
Show inline comments
 
#[macro_use]
 
mod macros;
 

	
 
// mod common;
 
mod protocol;
 
pub mod runtime;
 
pub mod runtime2;
 
mod collections;
 
mod random;
 

	
 
pub use protocol::{ProtocolDescription, ProtocolDescriptionBuilder, ComponentCreationError};
 
\ No newline at end of file
src/protocol/eval/executor.rs
Show inline comments
 
@@ -547,403 +547,403 @@ impl Prompt {
 
                                        _ => unreachable!("got concrete type {:?} for integer literal at expr {:?}", concrete_type, expr_id),
 
                                    }
 
                                }
 
                                Literal::Struct(lit_value) => {
 
                                    let heap_pos = transfer_expression_values_front_into_heap(
 
                                        cur_frame, &mut self.store, lit_value.fields.len()
 
                                    );
 
                                    Value::Struct(heap_pos)
 
                                }
 
                                Literal::Enum(lit_value) => {
 
                                    Value::Enum(lit_value.variant_idx as i64)
 
                                }
 
                                Literal::Union(lit_value) => {
 
                                    let heap_pos = transfer_expression_values_front_into_heap(
 
                                        cur_frame, &mut self.store, lit_value.values.len()
 
                                    );
 
                                    Value::Union(lit_value.variant_idx as i64, heap_pos)
 
                                }
 
                                Literal::Array(lit_value) => {
 
                                    let heap_pos = transfer_expression_values_front_into_heap(
 
                                        cur_frame, &mut self.store, lit_value.len()
 
                                    );
 
                                    Value::Array(heap_pos)
 
                                }
 
                                Literal::Tuple(lit_value) => {
 
                                    let heap_pos = transfer_expression_values_front_into_heap(
 
                                        cur_frame, &mut self.store, lit_value.len()
 
                                    );
 
                                    Value::Tuple(heap_pos)
 
                                }
 
                            };
 

	
 
                            cur_frame.expr_values.push_back(value);
 
                        },
 
                        Expression::Cast(expr) => {
 
                            let mono_data = &heap[cur_frame.definition].monomorphs[cur_frame.monomorph_index];
 
                            let type_id = mono_data.expr_info[expr.type_index as usize].type_id;
 
                            let concrete_type = &types.get_monomorph(type_id).concrete_type;
 

	
 
                            // Typechecking reduced this to two cases: either we
 
                            // have casting noop (same types), or we're casting
 
                            // between integer/bool/char types.
 
                            let subject = cur_frame.expr_values.pop_back().unwrap();
 
                            match apply_casting(&mut self.store, concrete_type, &subject) {
 
                                Ok(value) => cur_frame.expr_values.push_back(value),
 
                                Err(msg) => {
 
                                    return Err(EvalError::new_error_at_expr(self, modules, heap, expr.this.upcast(), msg));
 
                                }
 
                            }
 

	
 
                            self.store.drop_value(subject.get_heap_pos());
 
                        }
 
                        Expression::Call(expr) => {
 
                            // If we're dealing with a builtin we don't do any
 
                            // fancy shenanigans at all, just push the result.
 
                            match expr.method {
 
                                Method::Get => {
 
                                    let value = cur_frame.expr_values.pop_front().unwrap();
 
                                    let value = self.store.maybe_read_ref(&value).clone();
 

	
 
                                    let port_id = if let Value::Input(port_id) = value {
 
                                        port_id
 
                                    } else {
 
                                        unreachable!("executor calling 'get' on value {:?}", value)
 
                                    };
 

	
 
                                    match ctx.performed_get(port_id) {
 
                                        Some(result) => {
 
                                            // We have the result. Merge the `ValueGroup` with the
 
                                            // stack/heap storage.
 
                                            debug_assert_eq!(result.values.len(), 1);
 
                                            result.into_stack(&mut cur_frame.expr_values, &mut self.store);
 
                                        },
 
                                        None => {
 
                                            // Don't have the result yet, prepare the expression to
 
                                            // get run again after we've received a message.
 
                                            cur_frame.expr_values.push_front(value.clone());
 
                                            cur_frame.expr_stack.push_back(ExprInstruction::EvalExpr(expr_id));
 
                                            return Ok(EvalContinuation::BlockGet(port_id));
 
                                        }
 
                                    }
 
                                },
 
                                Method::Put => {
 
                                    let port_value = cur_frame.expr_values.pop_front().unwrap();
 
                                    let deref_port_value = self.store.maybe_read_ref(&port_value).clone();
 

	
 
                                    let port_id = if let Value::Output(port_id) = deref_port_value {
 
                                        port_id
 
                                    } else {
 
                                        unreachable!("executor calling 'put' on value {:?}", deref_port_value)
 
                                    };
 

	
 
                                    let msg_value = cur_frame.expr_values.pop_front().unwrap();
 
                                    let deref_msg_value = self.store.maybe_read_ref(&msg_value).clone();
 

	
 
                                    if ctx.performed_put(port_id) {
 
                                        // We're fine, deallocate in case the expression value stack
 
                                        // held an owned value
 
                                        self.store.drop_value(msg_value.get_heap_pos());
 
                                    } else {
 
                                        // Prepare to execute again
 
                                        cur_frame.expr_values.push_front(msg_value);
 
                                        cur_frame.expr_values.push_front(port_value);
 
                                        cur_frame.expr_stack.push_back(ExprInstruction::EvalExpr(expr_id));
 
                                        let value_group = ValueGroup::from_store(&self.store, &[deref_msg_value]);
 
                                        return Ok(EvalContinuation::Put(port_id, value_group));
 
                                    }
 
                                },
 
                                Method::Fires => {
 
                                    let port_value = cur_frame.expr_values.pop_front().unwrap();
 
                                    let port_value_deref = self.store.maybe_read_ref(&port_value).clone();
 
                                    let port_id = port_value_deref.as_port_id();
 

	
 
                                    match ctx.fires(port_id) {
 
                                        None => {
 
                                            cur_frame.expr_values.push_front(port_value);
 
                                            cur_frame.expr_stack.push_back(ExprInstruction::EvalExpr(expr_id));
 
                                            return Ok(EvalContinuation::BlockFires(port_id));
 
                                        },
 
                                        Some(value) => {
 
                                            cur_frame.expr_values.push_back(value);
 
                                        }
 
                                    }
 
                                },
 
                                Method::Create => {
 
                                    let length_value = cur_frame.expr_values.pop_front().unwrap();
 
                                    let length_value = self.store.maybe_read_ref(&length_value);
 
                                    let length = if length_value.is_signed_integer() {
 
                                        let length_value = length_value.as_signed_integer();
 
                                        if length_value < 0 {
 
                                            return Err(EvalError::new_error_at_expr(
 
                                                self, modules, heap, expr_id,
 
                                                format!("got length '{}', can only create a message with a non-negative length", length_value)
 
                                            ));
 
                                        }
 

	
 
                                        length_value as u64
 
                                    } else {
 
                                        debug_assert!(length_value.is_unsigned_integer());
 
                                        length_value.as_unsigned_integer()
 
                                    };
 

	
 
                                    let heap_pos = self.store.alloc_heap();
 
                                    let values = &mut self.store.heap_regions[heap_pos as usize].values;
 
                                    debug_assert!(values.is_empty());
 
                                    values.resize(length as usize, Value::UInt8(0));
 
                                    cur_frame.expr_values.push_back(Value::Message(heap_pos));
 
                                },
 
                                Method::Length => {
 
                                    let value = cur_frame.expr_values.pop_front().unwrap();
 
                                    let value_heap_pos = value.get_heap_pos();
 
                                    let value = self.store.maybe_read_ref(&value);
 

	
 
                                    let heap_pos = match value {
 
                                        Value::Array(pos) => *pos,
 
                                        Value::String(pos) => *pos,
 
                                        _ => unreachable!("length(...) on {:?}", value),
 
                                    };
 

	
 
                                    let len = self.store.heap_regions[heap_pos as usize].values.len();
 

	
 
                                    // TODO: @PtrInt
 
                                    cur_frame.expr_values.push_back(Value::UInt32(len as u32));
 
                                    self.store.drop_value(value_heap_pos);
 
                                },
 
                                Method::Assert => {
 
                                    let value = cur_frame.expr_values.pop_front().unwrap();
 
                                    let value = self.store.maybe_read_ref(&value).clone();
 
                                    if !value.as_bool() {
 
                                        return Ok(EvalContinuation::BranchInconsistent)
 
                                    }
 
                                },
 
                                Method::Print => {
 
                                    // Convert the runtime-variant of a string
 
                                    // into an actual string.
 
                                    let value = cur_frame.expr_values.pop_front().unwrap();
 
                                    let value_heap_pos = value.as_string();
 
                                    let elements = &self.store.heap_regions[value_heap_pos as usize].values;
 

	
 
                                    let mut message = String::with_capacity(elements.len());
 
                                    for element in elements {
 
                                        message.push(element.as_char());
 
                                    }
 

	
 
                                    // Drop the heap-allocated value from the
 
                                    // store
 
                                    self.store.drop_heap_pos(value_heap_pos);
 
                                    println!("{}", message);
 
                                },
 
                                Method::SelectStart => {
 
                                    let num_cases = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_uint32();
 
                                    let num_ports = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_uint32();
 
                                    if !ctx.select_start(num_cases, num_ports) {
 
                                        return Ok(EvalContinuation::SelectStart(num_cases, num_ports))
 
                                    }
 

	
 
                                    return Ok(EvalContinuation::SelectStart(num_cases, num_ports));
 
                                },
 
                                Method::SelectRegisterCasePort => {
 
                                    let case_index = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_uint32();
 
                                    let port_index = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_uint32();
 
                                    let port_value = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_port_id();
 

	
 
                                    if !ctx.performed_select_start() {
 
                                        return Ok(EvalContinuation::SelectRegisterPort(case_index, port_index, port_value));
 
                                    }
 
                                    return Ok(EvalContinuation::SelectRegisterPort(case_index, port_index, port_value));
 
                                },
 
                                Method::SelectWait => {
 
                                    match ctx.performed_select_wait() {
 
                                        Some(select_index) => {
 
                                            cur_frame.expr_values.push_back(Value::UInt32(select_index));
 
                                        },
 
                                        None => return Ok(EvalContinuation::SelectWait),
 
                                        None => {
 
                                            cur_frame.expr_stack.push_back(ExprInstruction::EvalExpr(expr.this.upcast()));
 
                                            return Ok(EvalContinuation::SelectWait)
 
                                        },
 
                                    }
 
                                },
 
                                Method::UserComponent => {
 
                                    // This is actually handled by the evaluation
 
                                    // of the statement.
 
                                    debug_assert_eq!(heap[expr.procedure].parameters.len(), cur_frame.expr_values.len());
 
                                    debug_assert_eq!(heap[cur_frame.position].as_new().expression, expr.this)
 
                                },
 
                                Method::UserFunction => {
 
                                    // Push a new frame. Note that all expressions have
 
                                    // been pushed to the front, so they're in the order
 
                                    // of the definition.
 
                                    let num_args = expr.arguments.len();
 

	
 
                                    // Determine stack boundaries
 
                                    let cur_stack_boundary = self.store.cur_stack_boundary;
 
                                    let new_stack_boundary = self.store.stack.len();
 

	
 
                                    // Push new boundary and function arguments for new frame
 
                                    self.store.stack.push(Value::PrevStackBoundary(cur_stack_boundary as isize));
 
                                    for _ in 0..num_args {
 
                                        let argument = self.store.read_take_ownership(cur_frame.expr_values.pop_front().unwrap());
 
                                        self.store.stack.push(argument);
 
                                    }
 

	
 
                                    // Determine the monomorph index of the function we're calling
 
                                    let mono_data = &heap[cur_frame.definition].monomorphs[cur_frame.monomorph_index];
 
                                    let (type_id, monomorph_index) = mono_data.expr_info[expr.type_index as usize].variant.as_procedure();
 

	
 
                                    // Push the new frame and reserve its stack size
 
                                    let new_frame = Frame::new(heap, expr.procedure, type_id, monomorph_index);
 
                                    let new_stack_size = new_frame.max_stack_size;
 
                                    self.frames.push(new_frame);
 
                                    self.store.cur_stack_boundary = new_stack_boundary;
 
                                    self.store.reserve_stack(new_stack_size);
 

	
 
                                    // To simplify the logic a little bit we will now
 
                                    // return and ask our caller to call us again
 
                                    return Ok(EvalContinuation::Stepping);
 
                                }
 
                            }
 
                        },
 
                        Expression::Variable(expr) => {
 
                            let variable = &heap[expr.declaration.unwrap()];
 
                            let ref_value = if expr.used_as_binding_target {
 
                                Value::Binding(variable.unique_id_in_scope as StackPos)
 
                            } else {
 
                                Value::Ref(ValueId::Stack(variable.unique_id_in_scope as StackPos))
 
                            };
 
                            cur_frame.expr_values.push_back(ref_value);
 
                        }
 
                    }
 
                }
 
            }
 
        }
 

	
 
        debug_log!("Frame [{:?}] at {:?}", cur_frame.definition, cur_frame.position);
 
        if debug_enabled!() {
 
            debug_log!("Expression value stack (size = {}):", cur_frame.expr_values.len());
 
            for (_stack_idx, _stack_val) in cur_frame.expr_values.iter().enumerate() {
 
                debug_log!("  [{:03}] {:?}", _stack_idx, _stack_val);
 
            }
 

	
 
            debug_log!("Stack (size = {}):", self.store.stack.len());
 
            for (_stack_idx, _stack_val) in self.store.stack.iter().enumerate() {
 
                debug_log!("  [{:03}] {:?}", _stack_idx, _stack_val);
 
            }
 

	
 
            debug_log!("Heap:");
 
            for (_heap_idx, _heap_region) in self.store.heap_regions.iter().enumerate() {
 
                let _is_free = self.store.free_regions.iter().any(|idx| *idx as usize == _heap_idx);
 
                debug_log!("  [{:03}] in_use: {}, len: {}, vals: {:?}", _heap_idx, !_is_free, _heap_region.values.len(), &_heap_region.values);
 
            }
 
        }
 
        // No (more) expressions to evaluate. So evaluate statement (that may
 
        // depend on the result on the last evaluated expression(s))
 
        let stmt = &heap[cur_frame.position];
 
        let return_value = match stmt {
 
            Statement::Block(stmt) => {
 
                debug_assert!(stmt.statements.is_empty() || stmt.next == stmt.statements[0]);
 
                cur_frame.position = stmt.next;
 
                Ok(EvalContinuation::Stepping)
 
            },
 
            Statement::EndBlock(stmt) => {
 
                let block = &heap[stmt.start_block];
 
                let scope = &heap[block.scope];
 
                self.store.clear_stack(scope.first_unique_id_in_scope as usize);
 
                cur_frame.position = stmt.next;
 

	
 
                Ok(EvalContinuation::Stepping)
 
            },
 
            Statement::Local(stmt) => {
 
                match stmt {
 
                    LocalStatement::Memory(stmt) => {
 
                        dbg_code!({
 
                            let variable = &heap[stmt.variable];
 
                            debug_assert!(match self.store.read_ref(ValueId::Stack(variable.unique_id_in_scope as u32)) {
 
                                Value::Unassigned => false,
 
                                _ => true,
 
                            });
 
                        });
 

	
 
                        cur_frame.position = stmt.next;
 
                        Ok(EvalContinuation::Stepping)
 
                    },
 
                    LocalStatement::Channel(stmt) => {
 
                        // Need to create a new channel by requesting it from
 
                        // the runtime.
 
                        match ctx.created_channel() {
 
                            None => {
 
                                // No channel is pending. So request one
 
                                    Ok(EvalContinuation::NewChannel)
 
                            },
 
                            Some((put_port, get_port)) => {
 
                                self.store.write(ValueId::Stack(heap[stmt.from].unique_id_in_scope as u32), put_port);
 
                                self.store.write(ValueId::Stack(heap[stmt.to].unique_id_in_scope as u32), get_port);
 
                                cur_frame.position = stmt.next;
 
                                Ok(EvalContinuation::Stepping)
 
                            }
 
                        }
 
                    }
 
                }
 
            },
 
            Statement::Labeled(stmt) => {
 
                cur_frame.position = stmt.body;
 

	
 
                Ok(EvalContinuation::Stepping)
 
            },
 
            Statement::If(stmt) => {
 
                debug_assert_eq!(cur_frame.expr_values.len(), 1, "expected one expr value for if statement");
 
                let test_value = cur_frame.expr_values.pop_back().unwrap();
 
                let test_value = self.store.maybe_read_ref(&test_value).as_bool();
 
                if test_value {
 
                    cur_frame.position = stmt.true_case.body;
 
                } else if let Some(false_body) = stmt.false_case {
 
                    cur_frame.position = false_body.body;
 
                } else {
 
                    // Not true, and no false body
 
                    cur_frame.position = stmt.end_if.upcast();
 
                }
 

	
 
                Ok(EvalContinuation::Stepping)
 
            },
 
            Statement::EndIf(stmt) => {
 
                cur_frame.position = stmt.next;
 
                let if_stmt = &heap[stmt.start_if];
 
                debug_assert_eq!(
 
                    heap[if_stmt.true_case.scope].first_unique_id_in_scope,
 
                    heap[if_stmt.false_case.unwrap_or(if_stmt.true_case).scope].first_unique_id_in_scope,
 
                );
 
                let scope = &heap[if_stmt.true_case.scope];
 
                self.store.clear_stack(scope.first_unique_id_in_scope as usize);
 
                Ok(EvalContinuation::Stepping)
 
            },
 
            Statement::While(stmt) => {
 
                debug_assert_eq!(cur_frame.expr_values.len(), 1, "expected one expr value for while statement");
 
                let test_value = cur_frame.expr_values.pop_back().unwrap();
 
                let test_value = self.store.maybe_read_ref(&test_value).as_bool();
 
                if test_value {
 
                    cur_frame.position = stmt.body;
 
                } else {
 
                    cur_frame.position = stmt.end_while.upcast();
 
                }
 

	
 
                Ok(EvalContinuation::Stepping)
 
            },
 
            Statement::EndWhile(stmt) => {
 
                cur_frame.position = stmt.next;
 
                let start_while = &heap[stmt.start_while];
 
                let scope = &heap[start_while.scope];
 
                self.store.clear_stack(scope.first_unique_id_in_scope as usize);
 
                Ok(EvalContinuation::Stepping)
 
            },
 
            Statement::Break(stmt) => {
 
                cur_frame.position = stmt.target.upcast();
 

	
 
                Ok(EvalContinuation::Stepping)
 
            },
 
            Statement::Continue(stmt) => {
 
                cur_frame.position = stmt.target.upcast();
 

	
 
                Ok(EvalContinuation::Stepping)
 
            },
 
            Statement::Synchronous(stmt) => {
 
                cur_frame.position = stmt.body;
 

	
 
                Ok(EvalContinuation::SyncBlockStart)
 
            },
 
            Statement::EndSynchronous(stmt) => {
 
                cur_frame.position = stmt.next;
 
                let start_synchronous = &heap[stmt.start_sync];
 
                let scope = &heap[start_synchronous.scope];
src/protocol/mod.rs
Show inline comments
 
@@ -26,234 +26,232 @@ pub struct Module {
 
}
 
/// Description of a protocol object, used to configure new connectors.
 
#[repr(C)]
 
pub struct ProtocolDescription {
 
    pub(crate) modules: Vec<Module>,
 
    pub(crate) heap: Heap,
 
    pub(crate) types: TypeTable,
 
    pub(crate) pool: Mutex<StringPool>,
 
}
 
#[derive(Debug, Clone)]
 
pub(crate) struct ComponentState {
 
    pub(crate) prompt: Prompt,
 
}
 

	
 
#[derive(Debug)]
 
pub enum ComponentCreationError {
 
    ModuleDoesntExist,
 
    DefinitionDoesntExist,
 
    DefinitionNotComponent,
 
    InvalidNumArguments,
 
    InvalidArgumentType(usize),
 
    UnownedPort,
 
    InSync,
 
}
 

	
 
impl ProtocolDescription {
 
    pub fn parse(buffer: &[u8]) -> Result<Self, String> {
 
        let source = InputSource::new(String::new(), Vec::from(buffer));
 
        let mut parser = Parser::new();
 
        parser.feed(source).expect("failed to feed source");
 
        
 
        if let Err(err) = parser.parse() {
 
            println!("ERROR:\n{}", err);
 
            return Err(format!("{}", err))
 
        }
 

	
 
        debug_assert_eq!(parser.modules.len(), 1, "only supporting one module here for now");
 
        let modules: Vec<Module> = parser.modules.into_iter()
 
            .map(|module| Module{
 
                source: module.source,
 
                root_id: module.root_id,
 
                name: module.name.map(|(_, name)| name)
 
            })
 
            .collect();
 

	
 
        return Ok(ProtocolDescription {
 
            modules,
 
            heap: parser.heap,
 
            types: parser.type_table,
 
            pool: Mutex::new(parser.string_pool),
 
        });
 
    }
 

	
 
    pub(crate) fn new_component(
 
        &self, module_name: &[u8], identifier: &[u8], arguments: ValueGroup
 
    ) -> Result<Prompt, ComponentCreationError> {
 
        // Find the module in which the definition can be found
 
        let module_root = self.lookup_module_root(module_name);
 
        if module_root.is_none() {
 
            return Err(ComponentCreationError::ModuleDoesntExist);
 
        }
 
        let module_root = module_root.unwrap();
 

	
 
        let root = &self.heap[module_root];
 
        let definition_id = root.get_definition_ident(&self.heap, identifier);
 
        if definition_id.is_none() {
 
            return Err(ComponentCreationError::DefinitionDoesntExist);
 
        }
 
        let definition_id = definition_id.unwrap();
 

	
 
        let ast_definition = &self.heap[definition_id];
 
        if !ast_definition.is_procedure() {
 
            return Err(ComponentCreationError::DefinitionNotComponent);
 
        }
 

	
 
        // Make sure that the types of the provided value group matches that of
 
        // the expected types.
 
        let ast_definition = ast_definition.as_procedure();
 
        if !ast_definition.poly_vars.is_empty() || ast_definition.kind == ProcedureKind::Function {
 
            return Err(ComponentCreationError::DefinitionNotComponent);
 
        }
 

	
 
        // - check number of arguments by retrieving the one instantiated
 
        //   monomorph
 
        let concrete_type = ConcreteType{ parts: vec![ConcreteTypePart::Component(ast_definition.this, 0)] };
 
        let procedure_type_id = self.types.get_procedure_monomorph_type_id(&definition_id, &concrete_type.parts).unwrap();
 
        let procedure_monomorph_index = self.types.get_monomorph(procedure_type_id).variant.as_procedure().monomorph_index;
 
        let monomorph_info = &ast_definition.monomorphs[procedure_monomorph_index as usize];
 
        if monomorph_info.argument_types.len() != arguments.values.len() {
 
            return Err(ComponentCreationError::InvalidNumArguments);
 
        }
 

	
 
        // - for each argument try to make sure the types match
 
        for arg_idx in 0..arguments.values.len() {
 
            let expected_type_id = monomorph_info.argument_types[arg_idx];
 
            let expected_type = &self.types.get_monomorph(expected_type_id).concrete_type;
 
            let provided_value = &arguments.values[arg_idx];
 
            if !self.verify_same_type(expected_type, 0, &arguments, provided_value) {
 
                return Err(ComponentCreationError::InvalidArgumentType(arg_idx));
 
            }
 
        }
 

	
 
        // By now we're sure that all of the arguments are correct. So create
 
        // the connector.
 
        return Ok(Prompt::new(&self.types, &self.heap, ast_definition.this, procedure_type_id, arguments));
 
    }
 

	
 
    fn lookup_module_root(&self, module_name: &[u8]) -> Option<RootId> {
 
        for module in self.modules.iter() {
 
            match &module.name {
 
                Some(name) => if name.as_bytes() == module_name {
 
                    return Some(module.root_id);
 
                },
 
                None => if module_name.is_empty() {
 
                    return Some(module.root_id);
 
                }
 
            }
 
        }
 

	
 
        return None;
 
    }
 

	
 
    fn verify_same_type(&self, expected: &ConcreteType, expected_idx: usize, arguments: &ValueGroup, argument: &Value) -> bool {
 
        use ConcreteTypePart as CTP;
 

	
 
        match &expected.parts[expected_idx] {
 
            CTP::Void | CTP::Message | CTP::Slice | CTP::Pointer | CTP::Function(_, _) | CTP::Component(_, _) => unreachable!(),
 
            CTP::Bool => if let Value::Bool(_) = argument { true } else { false },
 
            CTP::UInt8 => if let Value::UInt8(_) = argument { true } else { false },
 
            CTP::UInt16 => if let Value::UInt16(_) = argument { true } else { false },
 
            CTP::UInt32 => if let Value::UInt32(_) = argument { true } else { false },
 
            CTP::UInt64 => if let Value::UInt64(_) = argument { true } else { false },
 
            CTP::SInt8 => if let Value::SInt8(_) = argument { true } else { false },
 
            CTP::SInt16 => if let Value::SInt16(_) = argument { true } else { false },
 
            CTP::SInt32 => if let Value::SInt32(_) = argument { true } else { false },
 
            CTP::SInt64 => if let Value::SInt64(_) = argument { true } else { false },
 
            CTP::Character => if let Value::Char(_) = argument { true } else { false },
 
            CTP::String => {
 
                // Match outer string type and embedded character types
 
                if let Value::String(heap_pos) = argument {
 
                    for element in &arguments.regions[*heap_pos as usize] {
 
                        if let Value::Char(_) = element {} else {
 
                            return false;
 
                        }
 
                    }
 
                } else {
 
                    return false;
 
                }
 

	
 
                return true;
 
            },
 
            CTP::Array => {
 
                if let Value::Array(heap_pos) = argument {
 
                    let heap_pos = *heap_pos;
 
                    for element in &arguments.regions[heap_pos as usize] {
 
                        if !self.verify_same_type(expected, expected_idx + 1, arguments, element) {
 
                            return false;
 
                        }
 
                    }
 
                    return true;
 
                } else {
 
                    return false;
 
                }
 
            },
 
            CTP::Input => if let Value::Input(_) = argument { true } else { false },
 
            CTP::Output => if let Value::Output(_) = argument { true } else { false },
 
            CTP::Tuple(_) => todo!("implement full type checking on user-supplied arguments"),
 
            CTP::Instance(definition_id, _num_embedded) => {
 
                let definition = self.types.get_base_definition(definition_id).unwrap();
 
                match &definition.definition {
 
                    DefinedTypeVariant::Enum(definition) => {
 
                        if let Value::Enum(variant_value) = argument {
 
                            let is_valid = definition.variants.iter()
 
                                .any(|v| v.value == *variant_value);
 
                            return is_valid;
 
                        }
 
                    },
 
                    _ => todo!("implement full type checking on user-supplied arguments"),
 
                }
 

	
 
                return false;
 
            },
 
        }
 
    }
 
}
 

	
 
pub trait RunContext {
 
    fn performed_put(&mut self, port: PortId) -> bool;
 
    fn performed_get(&mut self, port: PortId) -> Option<ValueGroup>; // None if still waiting on message
 
    fn fires(&mut self, port: PortId) -> Option<Value>; // None if not yet branched
 
    fn performed_fork(&mut self) -> Option<bool>; // None if not yet forked
 
    fn created_channel(&mut self) -> Option<(Value, Value)>; // None if not yet prepared
 
    fn performed_select_start(&mut self) -> bool; // true if performed
 
    fn performed_select_register_port(&mut self) -> bool; // true if registered
 
    fn performed_select_wait(&mut self) -> Option<u32>; // None if not yet notified runtime of select blocker
 
}
 

	
 
pub struct ProtocolDescriptionBuilder {
 
    parser: Parser,
 
}
 

	
 
impl ProtocolDescriptionBuilder {
 
    pub fn new() -> Self {
 
        return Self{
 
            parser: Parser::new(),
 
        }
 
    }
 

	
 
    pub fn add(&mut self, filename: String, buffer: Vec<u8>) -> Result<(), ParseError> {
 
        let input = InputSource::new(filename, buffer);
 
        self.parser.feed(input)?;
 

	
 
        return Ok(())
 
    }
 

	
 
    pub fn compile(mut self) -> Result<ProtocolDescription, ParseError> {
 
        self.parser.parse()?;
 

	
 
        let modules: Vec<Module> = self.parser.modules.into_iter()
 
            .map(|module| Module{
 
                source: module.source,
 
                root_id: module.root_id,
 
                name: module.name.map(|(_, name)| name)
 
            })
 
            .collect();
 

	
 
        return Ok(ProtocolDescription {
 
            modules,
 
            heap: self.parser.heap,
 
            types: self.parser.type_table,
 
            pool: Mutex::new(self.parser.string_pool),
 
        });
 
    }
 
}
src/protocol/parser/pass_rewriting.rs
Show inline comments
 
use crate::collections::*;
 
use crate::protocol::*;
 

	
 
use super::visitor::*;
 

	
 
pub(crate) struct PassRewriting {
 
    current_scope: ScopeId,
 
    current_procedure_id: ProcedureDefinitionId,
 
    definition_buffer: ScopedBuffer<DefinitionId>,
 
    statement_buffer: ScopedBuffer<StatementId>,
 
    call_expr_buffer: ScopedBuffer<CallExpressionId>,
 
    expression_buffer: ScopedBuffer<ExpressionId>,
 
    scope_buffer: ScopedBuffer<ScopeId>,
 
}
 

	
 
impl PassRewriting {
 
    pub(crate) fn new() -> Self {
 
        Self{
 
            current_scope: ScopeId::new_invalid(),
 
            current_procedure_id: ProcedureDefinitionId::new_invalid(),
 
            definition_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE),
 
            statement_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL),
 
            call_expr_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL),
 
            expression_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL),
 
            scope_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL),
 
        }
 
    }
 
}
 

	
 
impl Visitor for PassRewriting {
 
    fn visit_module(&mut self, ctx: &mut Ctx) -> VisitorResult {
 
        let module = ctx.module();
 
        debug_assert_eq!(module.phase, ModuleCompilationPhase::Typed);
 

	
 
        let root_id = module.root_id;
 
        let root = &ctx.heap[root_id];
 
        let definition_section = self.definition_buffer.start_section_initialized(&root.definitions);
 
        for definition_index in 0..definition_section.len() {
 
            let definition_id = definition_section[definition_index];
 
            self.visit_definition(ctx, definition_id)?;
 
        }
 

	
 
        definition_section.forget();
 
        ctx.module_mut().phase = ModuleCompilationPhase::Rewritten;
 
        return Ok(())
 
    }
 

	
 
    // --- Visiting procedures
 

	
 
    fn visit_procedure_definition(&mut self, ctx: &mut Ctx, id: ProcedureDefinitionId) -> VisitorResult {
 
        let definition = &ctx.heap[id];
 
        let body_id = definition.body;
 
        self.current_scope = definition.scope;
 
        self.current_procedure_id = id;
 
        return self.visit_block_stmt(ctx, body_id);
 
    }
 

	
 
    // --- Visiting statements (that are not the select statement)
 

	
 
    fn visit_block_stmt(&mut self, ctx: &mut Ctx, id: BlockStatementId) -> VisitorResult {
 
        let block_stmt = &ctx.heap[id];
 
        let stmt_section = self.statement_buffer.start_section_initialized(&block_stmt.statements);
 

	
 
        self.current_scope = block_stmt.scope;
 
        for stmt_idx in 0..stmt_section.len() {
 
            self.visit_stmt(ctx, stmt_section[stmt_idx])?;
 
        }
 

	
 
        stmt_section.forget();
 
        return Ok(())
 
    }
 

	
 
    fn visit_labeled_stmt(&mut self, ctx: &mut Ctx, id: LabeledStatementId) -> VisitorResult {
 
        let labeled_stmt = &ctx.heap[id];
 
        let body_id = labeled_stmt.body;
 
        return self.visit_stmt(ctx, body_id);
 
    }
 

	
 
    fn visit_if_stmt(&mut self, ctx: &mut Ctx, id: IfStatementId) -> VisitorResult {
 
        let if_stmt = &ctx.heap[id];
 
        let true_case = if_stmt.true_case;
 
        let false_case = if_stmt.false_case;
 

	
 
        self.current_scope = true_case.scope;
 
        self.visit_stmt(ctx, true_case.body)?;
 
        if let Some(false_case) = false_case {
 
            self.current_scope = false_case.scope;
 
            self.visit_stmt(ctx, false_case.body)?;
 
        }
 

	
 
        return Ok(())
 
    }
 

	
 
    fn visit_while_stmt(&mut self, ctx: &mut Ctx, id: WhileStatementId) -> VisitorResult {
 
        let while_stmt = &ctx.heap[id];
 
        let body_id = while_stmt.body;
 
        self.current_scope = while_stmt.scope;
 
        return self.visit_stmt(ctx, body_id);
 
    }
 

	
 
    fn visit_synchronous_stmt(&mut self, ctx: &mut Ctx, id: SynchronousStatementId) -> VisitorResult {
 
        let sync_stmt = &ctx.heap[id];
 
        let body_id = sync_stmt.body;
 
        self.current_scope = sync_stmt.scope;
 
        return self.visit_stmt(ctx, body_id);
 
    }
 

	
 
    // --- Visiting the select statement
 

	
 
    fn visit_select_stmt(&mut self, ctx: &mut Ctx, id: SelectStatementId) -> VisitorResult {
 
        // Utility for the last stage of rewriting process. Note that caller
 
        // still needs to point the end of the if-statement to the end of the
 
        // replacement statement of the select statement.
 
        fn transform_select_case_code(
 
            ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId,
 
            select_id: SelectStatementId, case_index: usize,
 
            select_var_id: VariableId, select_var_type_id: TypeIdReference
 
        ) -> (IfStatementId, EndIfStatementId) {
 
        ) -> (IfStatementId, EndIfStatementId, ScopeId) {
 
            // Retrieve statement IDs associated with case
 
            let case = &ctx.heap[select_id].cases[case_index];
 
            let case_guard_id = case.guard;
 
            let case_body_id = case.body;
 
            let case_scope_id = case.scope;
 

	
 
            // Create the if-statement for the result of the select statement
 
            let compare_expr_id = create_ast_equality_comparison_expr(ctx, containing_procedure_id, select_var_id, select_var_type_id, case_index as u64);
 
            let true_case = IfStatementCase{
 
                body: case_guard_id, // which is linked up to the body
 
                scope: case_scope_id,
 
            };
 
            let (if_stmt_id, end_if_stmt_id) = create_ast_if_stmt(ctx, compare_expr_id.upcast(), true_case, None);
 

	
 
            // Link up body statement to end-if
 
            set_ast_statement_next(ctx, case_body_id, end_if_stmt_id.upcast());
 

	
 
            return (if_stmt_id, end_if_stmt_id)
 
            return (if_stmt_id, end_if_stmt_id, case_scope_id);
 
        }
 

	
 
        // Precreate the block that will end up containing all of the
 
        // transformed statements. Also precreate the scope associated with it
 
        let (outer_block_id, outer_end_block_id, outer_scope_id) =
 
            create_ast_block_stmt(ctx, Vec::new());
 

	
 
        // The "select" and the "end select" statement will act like trampolines
 
        // that jump to the replacement block. So set the child/parent
 
        // relationship already.
 
        // --- for the statements
 
        let select_stmt = &mut ctx.heap[id];
 
        select_stmt.next = outer_block_id.upcast();
 
        let end_select_stmt_id = select_stmt.end_select;
 
        let select_stmt_relative_pos = select_stmt.relative_pos_in_parent;
 

	
 
        let outer_end_block_stmt = &mut ctx.heap[outer_end_block_id];
 
        outer_end_block_stmt.next = end_select_stmt_id.upcast();
 

	
 
        // --- for the scopes
 
        link_new_child_to_existing_parent_scope(ctx, &mut self.scope_buffer, self.current_scope, outer_scope_id, select_stmt_relative_pos);
 

	
 
        // Create statements that will create temporary variables for all of the
 
        // ports passed to the "get" calls in the select case guards.
 
        let select_stmt = &ctx.heap[id];
 
        let total_num_cases = select_stmt.cases.len();
 
        let mut total_num_ports = 0;
 
        let end_select_stmt_id = select_stmt.end_select;
 
        let _end_select = &ctx.heap[end_select_stmt_id];
 

	
 
        // Put heap IDs into temporary buffers to handle borrowing rules
 
        let mut call_id_section = self.call_expr_buffer.start_section();
 
        let mut expr_id_section = self.expression_buffer.start_section();
 

	
 
        for case in select_stmt.cases.iter() {
 
            total_num_ports += case.involved_ports.len();
 
            for (call_id, expr_id) in case.involved_ports.iter().copied() {
 
                call_id_section.push(call_id);
 
                expr_id_section.push(expr_id);
 
            }
 
        }
 

	
 
        // Transform all of the call expressions by takings its argument (the
 
        // port from which we `get`) and turning it into a temporary variable.
 
        let mut transformed_stmts = Vec::with_capacity(total_num_ports); // TODO: Recompute this preallocated length, put assert at the end
 
        let mut locals = Vec::with_capacity(total_num_ports);
 

	
 
        for port_var_idx in 0..call_id_section.len() {
 
            let get_call_expr_id = call_id_section[port_var_idx];
 
            let port_expr_id = expr_id_section[port_var_idx];
 
            let port_type_index = ctx.heap[port_expr_id].type_index();
 
            let port_type_ref = TypeIdReference::IndirectSameAsExpr(port_type_index);
 

	
 
            // Move the port expression such that it gets assigned to a temporary variable
 
            let variable_id = create_ast_variable(ctx, outer_scope_id);
 
            let variable_decl_stmt_id = create_ast_variable_declaration_stmt(ctx, self.current_procedure_id, variable_id, port_type_ref, port_expr_id);
 

	
 
            // Replace the original port expression in the call with a reference
 
            // to the replacement variable
 
            let variable_expr_id = create_ast_variable_expr(ctx, self.current_procedure_id, variable_id, port_type_ref);
 
            let call_expr = &mut ctx.heap[get_call_expr_id];
 
            call_expr.arguments[0] = variable_expr_id.upcast();
 

	
 
            transformed_stmts.push(variable_decl_stmt_id.upcast().upcast());
 
            locals.push((variable_id, port_type_ref));
 
        }
 

	
 
        // Insert runtime calls that facilitate the semantics of the select
 
        // block.
 

	
 
        // Create the call that indicates the start of the select block
 
        {
 
            let num_cases_expression_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, total_num_cases as u64);
 
            let num_ports_expression_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, total_num_ports as u64);
 
            let num_cases_expression_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, total_num_cases as u64, ctx.arch.uint32_type_id);
 
            let num_ports_expression_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, total_num_ports as u64, ctx.arch.uint32_type_id);
 
            let arguments = vec![
 
                num_cases_expression_id.upcast(),
 
                num_ports_expression_id.upcast()
 
            ];
 

	
 
            let call_expression_id = create_ast_call_expr(ctx, self.current_procedure_id, Method::SelectStart, &mut self.expression_buffer, arguments);
 
            let call_statement_id = create_ast_expression_stmt(ctx, call_expression_id.upcast());
 

	
 
            transformed_stmts.push(call_statement_id.upcast());
 
        }
 

	
 
        // Create calls for each select case that will register the ports that
 
        // we are waiting on at the runtime.
 
        {
 
            let mut total_port_index = 0;
 
            for case_index in 0..total_num_cases {
 
                let case = &ctx.heap[id].cases[case_index];
 
                let case_num_ports = case.involved_ports.len();
 

	
 
                for case_port_index in 0..case_num_ports {
 
                    // Arguments to runtime call
 
                    let (port_variable_id, port_variable_type) = locals[total_port_index]; // so far this variable contains the temporary variables for the port expressions
 
                    let case_index_expr_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, case_index as u64);
 
                    let port_index_expr_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, case_port_index as u64);
 
                    let case_index_expr_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, case_index as u64, ctx.arch.uint32_type_id);
 
                    let port_index_expr_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, case_port_index as u64, ctx.arch.uint32_type_id);
 
                    let port_variable_expr_id = create_ast_variable_expr(ctx, self.current_procedure_id, port_variable_id, port_variable_type);
 
                    let runtime_call_arguments = vec![
 
                        case_index_expr_id.upcast(),
 
                        port_index_expr_id.upcast(),
 
                        port_variable_expr_id.upcast()
 
                    ];
 

	
 
                    // Create runtime call, then store it
 
                    let runtime_call_expr_id = create_ast_call_expr(ctx, self.current_procedure_id, Method::SelectRegisterCasePort, &mut self.expression_buffer, runtime_call_arguments);
 
                    let runtime_call_stmt_id = create_ast_expression_stmt(ctx, runtime_call_expr_id.upcast());
 

	
 
                    transformed_stmts.push(runtime_call_stmt_id.upcast());
 

	
 
                    total_port_index += 1;
 
                }
 
            }
 
        }
 

	
 
        // Create the variable that will hold the result of a completed select
 
        // block. Then create the runtime call that will produce this result
 
        let select_variable_id = create_ast_variable(ctx, outer_scope_id);
 
        let select_variable_type = TypeIdReference::DirectTypeId(ctx.arch.uint32_type_id);
 
        locals.push((select_variable_id, select_variable_type));
 

	
 
        {
 
            let runtime_call_expr_id = create_ast_call_expr(ctx, self.current_procedure_id, Method::SelectWait, &mut self.expression_buffer, Vec::new());
 
            let variable_stmt_id = create_ast_variable_declaration_stmt(ctx, self.current_procedure_id, select_variable_id, select_variable_type, runtime_call_expr_id.upcast());
 
            transformed_stmts.push(variable_stmt_id.upcast().upcast());
 
        }
 

	
 
        call_id_section.forget();
 
        expr_id_section.forget();
 

	
 
        // Now we transform each of the select block case's guard and code into
 
        // a chained if-else statement.
 
        let mut relative_pos = transformed_stmts.len() as i32;
 
        if total_num_cases > 0 {
 
            let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, self.current_procedure_id, id, 0, select_variable_id, select_variable_type);
 
            let (if_stmt_id, end_if_stmt_id, scope_id) = transform_select_case_code(ctx, self.current_procedure_id, id, 0, select_variable_id, select_variable_type);
 
            link_existing_child_to_new_parent_scope(ctx, &mut self.scope_buffer, outer_scope_id, scope_id, relative_pos);
 
            let first_end_if_stmt = &mut ctx.heap[end_if_stmt_id];
 
            first_end_if_stmt.next = outer_end_block_id.upcast();
 

	
 
            let mut last_if_stmt_id = if_stmt_id;
 
            let mut last_end_if_stmt_id = end_if_stmt_id;
 
            let mut last_parent_scope_id = outer_scope_id;
 
            let mut last_relative_pos = transformed_stmts.len() as i32 + 1;
 
            transformed_stmts.push(last_if_stmt_id.upcast());
 

	
 
            for case_index in 1..total_num_cases {
 
                let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, self.current_procedure_id, id, case_index, select_variable_id, select_variable_type);
 
                let (if_stmt_id, end_if_stmt_id, scope_id) = transform_select_case_code(ctx, self.current_procedure_id, id, case_index, select_variable_id, select_variable_type);
 
                let false_case_scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::If(last_if_stmt_id, false)));
 
                link_existing_child_to_new_parent_scope(ctx, &mut self.scope_buffer, false_case_scope_id, scope_id, 0);
 
                link_new_child_to_existing_parent_scope(ctx, &mut self.scope_buffer, last_parent_scope_id, false_case_scope_id, last_relative_pos);
 
                set_ast_if_statement_false_body(ctx, last_if_stmt_id, last_end_if_stmt_id, IfStatementCase{ body: if_stmt_id.upcast(), scope: false_case_scope_id });
 

	
 
                let end_if_stmt = &mut ctx.heap[end_if_stmt_id];
 
                end_if_stmt.next = last_end_if_stmt_id.upcast();
 

	
 
                last_if_stmt_id = if_stmt_id;
 
                last_end_if_stmt_id = end_if_stmt_id;
 
                last_parent_scope_id = false_case_scope_id;
 
                last_relative_pos = 0;
 
            }
 
        }
 

	
 
        // Final steps: set the statements of the replacement block statement,
 
        // and link all of those statements together
 
        // link all of those statements together, and update the scopes.
 
        let first_stmt_id = transformed_stmts[0];
 
        let mut last_stmt_id = transformed_stmts[0];
 
        for stmt_id in transformed_stmts.iter().skip(1).copied() {
 
            set_ast_statement_next(ctx, last_stmt_id, stmt_id);
 
            last_stmt_id = stmt_id;
 
        }
 

	
 
        let outer_block_stmt = &mut ctx.heap[outer_block_id];
 
        outer_block_stmt.next = first_stmt_id;
 
        outer_block_stmt.statements = transformed_stmts;
 

	
 
        return Ok(())
 
    }
 
}
 

	
 
// -----------------------------------------------------------------------------
 
// Utilities to create compiler-generated AST nodes
 
// -----------------------------------------------------------------------------
 

	
 
#[derive(Clone, Copy)]
 
enum TypeIdReference {
 
    DirectTypeId(TypeId),
 
    IndirectSameAsExpr(i32), // by type index
 
}
 

	
 
fn create_ast_variable(ctx: &mut Ctx, scope_id: ScopeId) -> VariableId {
 
    let variable_id = ctx.heap.alloc_variable(|this| Variable{
 
        this,
 
        kind: VariableKind::Local,
 
        parser_type: ParserType{
 
            elements: Vec::new(),
 
            full_span: InputSpan::new(),
 
        },
 
        identifier: Identifier::new_empty(InputSpan::new()),
 
        relative_pos_in_parent: -1,
 
        unique_id_in_scope: -1,
 
    });
 
    let scope = &mut ctx.heap[scope_id];
 
    scope.variables.push(variable_id);
 

	
 
    return variable_id;
 
}
 

	
 
fn create_ast_variable_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, variable_id: VariableId, variable_type_id: TypeIdReference) -> VariableExpressionId {
 
    let variable_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, variable_type_id);
 
    return ctx.heap.alloc_variable_expression(|this| VariableExpression{
 
        this,
 
        identifier: Identifier::new_empty(InputSpan::new()),
 
        declaration: Some(variable_id),
 
        used_as_binding_target: false,
 
        parent: ExpressionParent::None,
 
        type_index: variable_type_index,
 
    });
 
}
 

	
 
fn create_ast_call_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, method: Method, buffer: &mut ScopedBuffer<ExpressionId>, arguments: Vec<ExpressionId>) -> CallExpressionId {
 
    let call_type_id = match method {
 
        Method::SelectStart => ctx.arch.void_type_id,
 
        Method::SelectRegisterCasePort => ctx.arch.void_type_id,
 
        Method::SelectWait => ctx.arch.uint32_type_id, // TODO: Not pretty, this. Pretty error prone
 
        _ => unreachable!(), // if this goes of, add the appropriate method here.
 
    };
 

	
 
    let expression_ids = buffer.start_section_initialized(&arguments);
 
    let call_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(call_type_id));
 
    let call_expression_id = ctx.heap.alloc_call_expression(|this| CallExpression{
 
        func_span: InputSpan::new(),
 
        this,
 
        full_span: InputSpan::new(),
 
        parser_type: ParserType{
 
            elements: Vec::new(),
 
            full_span: InputSpan::new(),
 
        },
 
        method,
 
        arguments,
 
        procedure: ProcedureDefinitionId::new_invalid(),
 
        parent: ExpressionParent::None,
 
        type_index: call_type_index,
 
    });
 

	
 
    for argument_index in 0..expression_ids.len() {
 
        let argument_id = expression_ids[argument_index];
 
        let argument_expr = &mut ctx.heap[argument_id];
 
        *argument_expr.parent_mut() = ExpressionParent::Expression(call_expression_id.upcast(), argument_index as u32);
 
    }
 

	
 
    return call_expression_id;
 
}
 

	
 
fn create_ast_literal_integer_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, unsigned_value: u64) -> LiteralExpressionId {
 
    let literal_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(ctx.arch.uint64_type_id));
 
fn create_ast_literal_integer_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, unsigned_value: u64, type_id: TypeId) -> LiteralExpressionId {
 
    let literal_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(type_id));
 
    return ctx.heap.alloc_literal_expression(|this| LiteralExpression{
 
        this,
 
        span: InputSpan::new(),
 
        value: Literal::Integer(LiteralInteger{
 
            unsigned_value,
 
            negated: false,
 
        }),
 
        parent: ExpressionParent::None,
 
        type_index: literal_type_index,
 
    });
 
}
 

	
 
fn create_ast_equality_comparison_expr(
 
    ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId,
 
    variable_id: VariableId, variable_type: TypeIdReference, value: u64
 
) -> BinaryExpressionId {
 
    let var_expr_id = create_ast_variable_expr(ctx, containing_procedure_id, variable_id, variable_type);
 
    let int_expr_id = create_ast_literal_integer_expr(ctx, containing_procedure_id, value);
 
    let int_expr_id = create_ast_literal_integer_expr(ctx, containing_procedure_id, value, ctx.arch.uint32_type_id);
 
    let cmp_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(ctx.arch.bool_type_id));
 
    let cmp_expr_id = ctx.heap.alloc_binary_expression(|this| BinaryExpression{
 
        this,
 
        operator_span: InputSpan::new(),
 
        full_span: InputSpan::new(),
 
        left: var_expr_id.upcast(),
 
        operation: BinaryOperator::Equality,
 
        right: int_expr_id.upcast(),
 
        parent: ExpressionParent::None,
 
        type_index: cmp_type_index,
 
    });
 

	
 
    let var_expr = &mut ctx.heap[var_expr_id];
 
    var_expr.parent = ExpressionParent::Expression(cmp_expr_id.upcast(), 0);
 
    let int_expr = &mut ctx.heap[int_expr_id];
 
    int_expr.parent = ExpressionParent::Expression(cmp_expr_id.upcast(), 1);
 

	
 
    return cmp_expr_id;
 
}
 

	
 
fn create_ast_expression_stmt(ctx: &mut Ctx, expression_id: ExpressionId) -> ExpressionStatementId {
 
    let statement_id = ctx.heap.alloc_expression_statement(|this| ExpressionStatement{
 
        this,
 
        span: InputSpan::new(),
 
        expression: expression_id,
 
        next: StatementId::new_invalid(),
 
    });
 

	
 
    let expression = &mut ctx.heap[expression_id];
 
    *expression.parent_mut() = ExpressionParent::ExpressionStmt(statement_id);
 

	
 
    return statement_id;
 
}
 

	
 
fn create_ast_variable_declaration_stmt(
 
    ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId,
 
    variable_id: VariableId, variable_type: TypeIdReference, initial_value_expr_id: ExpressionId
 
) -> MemoryStatementId {
 
    // Create the assignment expression, assigning the initial value to the variable
 
    let variable_expr_id = create_ast_variable_expr(ctx, containing_procedure_id, variable_id, variable_type);
 
    let void_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(ctx.arch.void_type_id));
 
    let assignment_expr_id = ctx.heap.alloc_assignment_expression(|this| AssignmentExpression{
 
        this,
 
        operator_span: InputSpan::new(),
 
        full_span: InputSpan::new(),
 
        left: variable_expr_id.upcast(),
 
        operation: AssignmentOperator::Set,
 
        right: initial_value_expr_id,
 
        parent: ExpressionParent::None,
 
        type_index: -1,
 
        type_index: void_type_index,
 
    });
 

	
 
    // Create the memory statement
 
    let memory_stmt_id = ctx.heap.alloc_memory_statement(|this| MemoryStatement{
 
        this,
 
        span: InputSpan::new(),
 
        variable: variable_id,
 
        initial_expr: assignment_expr_id,
 
        next: StatementId::new_invalid(),
 
    });
 

	
 
    // Set all parents which we can access
 
    let variable_expr = &mut ctx.heap[variable_expr_id];
 
    variable_expr.parent = ExpressionParent::Expression(assignment_expr_id.upcast(), 0);
 
    let value_expr = &mut ctx.heap[initial_value_expr_id];
 
    *value_expr.parent_mut() = ExpressionParent::Expression(assignment_expr_id.upcast(), 1);
 
    let assignment_expr = &mut ctx.heap[assignment_expr_id];
 
    assignment_expr.parent = ExpressionParent::Memory(memory_stmt_id);
 

	
 
    return memory_stmt_id;
 
}
 

	
 
fn create_ast_block_stmt(ctx: &mut Ctx, statements: Vec<StatementId>) -> (BlockStatementId, EndBlockStatementId, ScopeId) {
 
    let block_stmt_id = ctx.heap.alloc_block_statement(|this| BlockStatement{
 
        this,
 
        span: InputSpan::new(),
 
        statements,
 
        end_block: EndBlockStatementId::new_invalid(),
 
        scope: ScopeId::new_invalid(),
 
        next: StatementId::new_invalid(),
 
    });
 
    let end_block_stmt_id = ctx.heap.alloc_end_block_statement(|this| EndBlockStatement{
 
        this,
 
        start_block: block_stmt_id,
 
        next: StatementId::new_invalid(),
 
    });
 
    let scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::Block(block_stmt_id)));
 

	
 
    let block_stmt = &mut ctx.heap[block_stmt_id];
 
    block_stmt.end_block = end_block_stmt_id;
 
    block_stmt.scope = scope_id;
 

	
 
    return (block_stmt_id, end_block_stmt_id, scope_id);
 
}
 

	
 
fn create_ast_if_stmt(ctx: &mut Ctx, condition_expression_id: ExpressionId, true_case: IfStatementCase, false_case: Option<IfStatementCase>) -> (IfStatementId, EndIfStatementId) {
 
    // Create if statement and the end-if statement
 
    let if_stmt_id = ctx.heap.alloc_if_statement(|this| IfStatement{
 
        this,
 
        span: InputSpan::new(),
 
        test: condition_expression_id,
 
        true_case,
 
        false_case,
 
        end_if: EndIfStatementId::new_invalid()
 
    });
 

	
 
    let end_if_stmt_id = ctx.heap.alloc_end_if_statement(|this| EndIfStatement{
 
        this,
 
        start_if: if_stmt_id,
 
        next: StatementId::new_invalid(),
 
    });
 

	
 
    // Link the statements up as much as we can
 
    let if_stmt = &mut ctx.heap[if_stmt_id];
 
    if_stmt.end_if = end_if_stmt_id;
 

	
 
    let condition_expr = &mut ctx.heap[condition_expression_id];
 
    *condition_expr.parent_mut() = ExpressionParent::If(if_stmt_id);
 

	
 

	
 

	
 
    return (if_stmt_id, end_if_stmt_id);
 
}
 

	
 
/// Sets the false body for a given
 
fn set_ast_if_statement_false_body(ctx: &mut Ctx, if_statement_id: IfStatementId, end_if_statement_id: EndIfStatementId, false_case: IfStatementCase) {
 
    // Point if-statement to "false body"
 
    let if_stmt = &mut ctx.heap[if_statement_id];
 
    debug_assert!(if_stmt.false_case.is_none()); // simplifies logic, not necessary
 
    if_stmt.false_case = Some(false_case);
 

	
 
    // Point end of false body to the end of the if statement
 
    set_ast_statement_next(ctx, false_case.body, end_if_statement_id.upcast());
 
}
 

	
 
/// Sets the specified AST statement's control flow such that it will be
 
/// followed by the target statement. This may seem obvious, but may imply that
 
/// a statement associated with, but different from, the source statement is
 
/// modified.
 
fn set_ast_statement_next(ctx: &mut Ctx, source_stmt_id: StatementId, target_stmt_id: StatementId) {
 
    let source_stmt = &mut ctx.heap[source_stmt_id];
 
    match source_stmt {
 
        Statement::Block(stmt) => {
 
            let end_id = stmt.end_block;
 
            ctx.heap[end_id].next = target_stmt_id
 
        },
 
        Statement::EndBlock(stmt) => stmt.next = target_stmt_id,
 
        Statement::Local(stmt) => {
 
            match stmt {
 
                LocalStatement::Memory(stmt) => stmt.next = target_stmt_id,
 
                LocalStatement::Channel(stmt) => stmt.next = target_stmt_id,
 
            }
 
        },
 
        Statement::Labeled(stmt) => {
 
            let body_id = stmt.body;
 
            set_ast_statement_next(ctx, body_id, target_stmt_id);
 
        },
 
        Statement::If(stmt) => {
 
            let end_id = stmt.end_if;
 
            ctx.heap[end_id].next = target_stmt_id;
 
        },
 
        Statement::EndIf(stmt) => stmt.next = target_stmt_id,
 
        Statement::While(stmt) => {
 
            let end_id = stmt.end_while;
 
            ctx.heap[end_id].next = target_stmt_id;
 
        },
 
        Statement::EndWhile(stmt) => stmt.next = target_stmt_id,
 

	
 
        Statement::Break(_stmt) => {},
 
        Statement::Continue(_stmt) => {},
 
        Statement::Synchronous(stmt) => {
 
            let end_id = stmt.end_sync;
 
            ctx.heap[end_id].next = target_stmt_id;
 
        },
 
        Statement::EndSynchronous(stmt) => {
 
            stmt.next = target_stmt_id;
 
        },
 
        Statement::Fork(_) | Statement::EndFork(_) => {
 
            todo!("remove fork from language");
 
        },
 
        Statement::Select(stmt) => {
 
            let end_id = stmt.end_select;
 
            ctx.heap[end_id].next = target_stmt_id;
 
        },
 
        Statement::EndSelect(stmt) => stmt.next = target_stmt_id,
 
        Statement::Return(_stmt) => {},
 
        Statement::Goto(_stmt) => {},
 
        Statement::New(stmt) => stmt.next = target_stmt_id,
 
        Statement::Expression(stmt) => stmt.next = target_stmt_id,
 
    }
 
}
 

	
 
/// Links a new scope to an existing scope as its child.
 
fn link_new_child_to_existing_parent_scope(ctx: &mut Ctx, scope_buffer: &mut ScopedBuffer<ScopeId>, parent_scope_id: ScopeId, child_scope_id: ScopeId, relative_pos_hint: i32) {
 
    let child_scope = &mut ctx.heap[child_scope_id];
 
    debug_assert!(child_scope.parent.is_none());
 

	
 
    child_scope.parent = Some(parent_scope_id);
 
    child_scope.relative_pos_in_parent = relative_pos_hint;
 

	
 
    add_child_scope_to_parent(ctx, scope_buffer, parent_scope_id, child_scope_id, relative_pos_hint);
 
}
 

	
 
/// Relinks an existing scope to a new scope as its child. Will also break the
 
/// link of the child scope's old parent.
 
fn link_existing_child_to_new_parent_scope(ctx: &mut Ctx, scope_buffer: &mut ScopedBuffer<ScopeId>, new_parent_scope_id: ScopeId, child_scope_id: ScopeId, new_relative_pos_in_parent: i32) {
 
    let child_scope = &mut ctx.heap[child_scope_id];
 
    let old_parent_scope_id = child_scope.parent.unwrap();
 
    child_scope.parent = Some(new_parent_scope_id);
 
    child_scope.relative_pos_in_parent = new_relative_pos_in_parent;
 

	
 
    // Remove from old parent
 
    let old_parent = &mut ctx.heap[old_parent_scope_id];
 
    let scope_index = old_parent.nested.iter()
 
        .position(|v| *v == child_scope_id)
 
        .unwrap();
 
    old_parent.nested.remove(scope_index);
 

	
 
    // Add to new parent
 
    add_child_scope_to_parent(ctx, scope_buffer, new_parent_scope_id, child_scope_id, new_relative_pos_in_parent);
 
}
 

	
 
/// Will add a child scope to a parent scope using the relative position hint.
 
fn add_child_scope_to_parent(ctx: &mut Ctx, scope_buffer: &mut ScopedBuffer<ScopeId>, parent_scope_id: ScopeId, child_scope_id: ScopeId, relative_pos_hint: i32) {
 
    let child_scope = &ctx.heap[child_scope_id];
 
    let parent_scope = &ctx.heap[parent_scope_id];
 

	
 
    let existing_scope_ids = scope_buffer.start_section_initialized(&child_scope.nested);
 
    let existing_scope_ids = scope_buffer.start_section_initialized(&parent_scope.nested);
 
    let mut insert_pos = existing_scope_ids.len();
 
    for index in 0..existing_scope_ids.len() {
 
        let existing_scope_id = existing_scope_ids[index];
 
        let existing_scope = &ctx.heap[existing_scope_id];
 
        if relative_pos_hint <= existing_scope.relative_pos_in_parent {
 
            insert_pos = index;
 
            break;
 
        }
 
    }
 
    existing_scope_ids.forget();
 

	
 
    let parent_scope = &mut ctx.heap[parent_scope_id];
 
    parent_scope.nested.insert(insert_pos, child_scope_id);
 
}
 

	
 
fn add_new_procedure_expression_type(ctx: &mut Ctx, procedure_id: ProcedureDefinitionId, type_id: TypeIdReference) -> i32 {
 
    let procedure = &mut ctx.heap[procedure_id];
 
    let type_index = procedure.monomorphs[0].expr_info.len();
 

	
 
    match type_id {
 
        TypeIdReference::DirectTypeId(type_id) => {
 
            for monomorph in procedure.monomorphs.iter_mut() {
 
                debug_assert_eq!(monomorph.expr_info.len(), type_index);
 
                monomorph.expr_info.push(ExpressionInfo{
 
                    type_id,
 
                    variant: ExpressionInfoVariant::Generic
 
                });
 
            }
 
        },
 
        TypeIdReference::IndirectSameAsExpr(source_type_index) => {
 
            for monomorph in procedure.monomorphs.iter_mut() {
 
                debug_assert_eq!(monomorph.expr_info.len(), type_index);
 
                let copied_expr_info = monomorph.expr_info[source_type_index as usize];
 
                monomorph.expr_info.push(copied_expr_info)
 
            }
 
        }
 
    }
 

	
 
    return type_index as i32;
 
}
 
}
 
\ No newline at end of file
src/protocol/parser/pass_typing.rs
Show inline comments
 
@@ -1920,389 +1920,384 @@ impl PassTyping {
 
            });
 

	
 
            var_data_index
 
        };
 

	
 
        let node = &mut self.infer_nodes[self_index];
 
        node.inference_rule = InferenceRule::VariableExpr(InferenceRuleVariableExpr{
 
            var_data_index,
 
        });
 

	
 
        self.parent_index = old_parent;
 
        self.progress_inference_rule_variable_expr(ctx, self_index)?;
 
        return Ok(self_index);
 
    }
 
}
 

	
 
// -----------------------------------------------------------------------------
 
// PassTyping - Type-inference progression
 
// -----------------------------------------------------------------------------
 

	
 
impl PassTyping {
 
    #[allow(dead_code)] // used when debug flag at the top of this file is true.
 
    fn debug_get_display_name(&self, ctx: &Ctx, node_index: InferNodeIndex) -> String {
 
        let expr_type = &self.infer_nodes[node_index].expr_type;
 
        expr_type.display_name(&ctx.heap)
 
    }
 

	
 
    fn resolve_types(&mut self, ctx: &mut Ctx, queue: &mut ResolveQueue) -> Result<(), ParseError> {
 
        // Keep inferring until we can no longer make any progress
 
        while !self.node_queued.is_empty() {
 
            while !self.node_queued.is_empty() {
 
                let node_index = self.node_queued.pop_front().unwrap();
 
                self.progress_inference_rule(ctx, node_index)?;
 
            }
 

	
 
            // Nothing is queued anymore. However we might have integer literals
 
            // whose type cannot be inferred. For convenience's sake we'll
 
            // infer these to be s32.
 
            for (infer_node_index, infer_node) in self.infer_nodes.iter_mut().enumerate() {
 
                let expr_type = &mut infer_node.expr_type;
 
                if !expr_type.is_done && expr_type.parts.len() == 1 && expr_type.parts[0] == InferenceTypePart::IntegerLike {
 
                    // Force integer type to s32
 
                    expr_type.parts[0] = InferenceTypePart::SInt32;
 
                    expr_type.is_done = true;
 

	
 
                    // Requeue expression (and its parent, if it exists)
 
                    self.node_queued.push_back(infer_node_index);
 
                    if let Some(node_parent_index) = infer_node.parent_index {
 
                        self.node_queued.push_back(node_parent_index);
 
                    }
 
                }
 
            }
 
        }
 

	
 
        // Helper for transferring polymorphic variables to concrete types and
 
        // checking if they're completely specified
 
        fn poly_data_type_to_concrete_type(
 
            ctx: &Ctx, expr_id: ExpressionId, inference_poly_args: &Vec<InferenceType>,
 
            first_concrete_part: ConcreteTypePart,
 
        ) -> Result<ConcreteType, ParseError> {
 
            // Prepare storage vector
 
            let mut num_inference_parts = 0;
 
            for inference_type in inference_poly_args {
 
                num_inference_parts += inference_type.parts.len();
 
            }
 

	
 
            let mut concrete_type = ConcreteType{
 
                parts: Vec::with_capacity(1 + num_inference_parts),
 
            };
 
            concrete_type.parts.push(first_concrete_part);
 

	
 
            // Go through all polymorphic arguments and add them to the concrete
 
            // types.
 
            for (poly_idx, poly_type) in inference_poly_args.iter().enumerate() {
 
                if !poly_type.is_done {
 
                    let expr = &ctx.heap[expr_id];
 
                    let definition = match expr {
 
                        Expression::Call(expr) => expr.procedure.upcast(),
 
                        Expression::Literal(expr) => match &expr.value {
 
                            Literal::Enum(lit) => lit.definition,
 
                            Literal::Union(lit) => lit.definition,
 
                            Literal::Struct(lit) => lit.definition,
 
                            _ => unreachable!()
 
                        },
 
                        _ => unreachable!(),
 
                    };
 
                    let poly_vars = ctx.heap[definition].poly_vars();
 
                    return Err(ParseError::new_error_at_span(
 
                        &ctx.module().source, expr.operation_span(), format!(
 
                            "could not fully infer the type of polymorphic variable '{}' of this expression (got '{}')",
 
                            poly_vars[poly_idx].value.as_str(), poly_type.display_name(&ctx.heap)
 
                        )
 
                    ));
 
                }
 

	
 
                poly_type.write_concrete_type(&mut concrete_type);
 
            }
 

	
 
            Ok(concrete_type)
 
        }
 

	
 
        // Every expression checked, and new monomorphs are queued. Transfer the
 
        // expression information to the AST. If this is the first time we're
 
        // visiting this procedure then we assign expression indices as well.
 
        let procedure = &ctx.heap[self.procedure_id];
 
        let num_infer_nodes = self.infer_nodes.len();
 
        let mut monomorph = ProcedureDefinitionMonomorph{
 
            argument_types: Vec::with_capacity(procedure.parameters.len()),
 
            expr_info: Vec::with_capacity(num_infer_nodes),
 
        };
 

	
 
        // For all of the expressions look up the TypeId (or create a new one).
 
        // For function calls and component instantiations figure out if they
 
        // need to be typechecked
 
        for infer_node in self.infer_nodes.iter_mut() {
 
            // Determine type ID
 
            let expr = &ctx.heap[infer_node.expr_id];
 

	
 
            // TODO: Maybe optimize? Split insertion up into lookup, then clone
 
            //  if needed?
 
            let mut concrete_type = ConcreteType::default();
 
            infer_node.expr_type.write_concrete_type(&mut concrete_type);
 
            let info_type_id = ctx.types.add_monomorphed_type(ctx.modules, ctx.heap, ctx.arch, concrete_type)?;
 

	
 
            // Determine procedure type ID, i.e. a called/instantiated
 
            // procedure's signature.
 
            let info_variant = if let Expression::Call(expr) = expr {
 
                // Construct full function type. If not yet typechecked then
 
                // queue it for typechecking.
 
                let poly_data = &self.poly_data[infer_node.poly_data_index as usize];
 
                debug_assert!(expr.method.is_user_defined() || expr.method.is_public_builtin());
 
                let procedure_id = expr.procedure;
 
                let num_poly_vars = poly_data.poly_vars.len() as u32;
 

	
 
                let first_part = match expr.method {
 
                    Method::UserFunction => ConcreteTypePart::Function(procedure_id, num_poly_vars),
 
                    Method::UserComponent => ConcreteTypePart::Component(procedure_id, num_poly_vars),
 
                    _ => ConcreteTypePart::Function(procedure_id, num_poly_vars),
 
                };
 

	
 

	
 
                let definition_id = procedure_id.upcast();
 
                let signature_type = poly_data_type_to_concrete_type(
 
                    ctx, infer_node.expr_id, &poly_data.poly_vars, first_part
 
                )?;
 

	
 
                let (type_id, monomorph_index) = if let Some(type_id) = ctx.types.get_procedure_monomorph_type_id(&definition_id, &signature_type.parts) {
 
                    // Procedure is already typechecked
 
                    let monomorph_index = ctx.types.get_monomorph(type_id).variant.as_procedure().monomorph_index;
 
                    (type_id, monomorph_index)
 
                } else {
 
                    // Procedure is not yet typechecked, reserve a TypeID and a monomorph index
 
                    let procedure_to_check = &mut ctx.heap[procedure_id];
 
                    let monomorph_index = procedure_to_check.monomorphs.len() as u32;
 
                    procedure_to_check.monomorphs.push(ProcedureDefinitionMonomorph::new_invalid());
 
                    let type_id = ctx.types.reserve_procedure_monomorph_type_id(&definition_id, signature_type, monomorph_index);
 

	
 
                    if !procedure_to_check.builtin {
 
                        // Only perform typechecking on the user-defined
 
                        // procedures
 
                        queue.push_back(ResolveQueueElement{
 
                            root_id: ctx.heap[definition_id].defined_in(),
 
                            definition_id,
 
                            reserved_type_id: type_id,
 
                            reserved_monomorph_index: monomorph_index,
 
                        });
 
                    }
 

	
 
                    (type_id, monomorph_index)
 
                };
 

	
 
                ExpressionInfoVariant::Procedure(type_id, monomorph_index)
 
            } else if let Expression::Select(_expr) = expr {
 
                ExpressionInfoVariant::Select(infer_node.field_index)
 
            } else {
 
                ExpressionInfoVariant::Generic
 
            };
 

	
 
            infer_node.info_type_id = info_type_id;
 
            infer_node.info_variant = info_variant;
 
        }
 

	
 
        // Write the types of the arguments
 
        let procedure = &ctx.heap[self.procedure_id];
 
        for parameter_id in procedure.parameters.iter().copied() {
 
            let mut concrete = ConcreteType::default();
 
            let var_data = self.var_data.iter().find(|v| v.var_id == parameter_id).unwrap();
 
            var_data.var_type.write_concrete_type(&mut concrete);
 
            let type_id = ctx.types.add_monomorphed_type(ctx.modules, ctx.heap, ctx.arch, concrete)?;
 
            monomorph.argument_types.push(type_id)
 
        }
 

	
 
        println!("DEBUG: For procedure {} with polyargs {:#?}", ctx.heap[self.procedure_id].identifier.value.as_str(), self.poly_vars);
 
        for infer_node in self.infer_nodes.iter() {
 
            println!("DEBUG: [{:?}] has type: {}", infer_node.expr_id, infer_node.expr_type.display_name(&ctx.heap));
 
        }
 

	
 
        // Determine if we have already assigned type indices to the expressions
 
        // before (the indices that, for a monomorph, can retrieve the type of
 
        // the expression).
 
        let has_type_indices = self.reserved_monomorph_index > 0;
 
        if has_type_indices {
 
            // already have indices, so resize and then index into it
 
            debug_assert!(monomorph.expr_info.is_empty());
 
            monomorph.expr_info.resize(num_infer_nodes, ExpressionInfo::new_invalid());
 
            for infer_node in self.infer_nodes.iter() {
 
                let type_index = ctx.heap[infer_node.expr_id].type_index();
 
                monomorph.expr_info[type_index as usize] = infer_node.as_expression_info();
 
            }
 
        } else {
 
            // no indices yet, need to be assigned in AST
 
            for infer_node in self.infer_nodes.iter() {
 
                let type_index = monomorph.expr_info.len();
 
                monomorph.expr_info.push(infer_node.as_expression_info());
 
                *ctx.heap[infer_node.expr_id].type_index_mut() = type_index as i32;
 
            }
 
        }
 

	
 
        // Push the information into the AST
 
        let procedure = &mut ctx.heap[self.procedure_id];
 
        procedure.monomorphs[self.reserved_monomorph_index as usize] = monomorph;
 

	
 
        Ok(())
 
    }
 

	
 
    fn progress_inference_rule(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> {
 
        use InferenceRule as IR;
 

	
 
        let node = &self.infer_nodes[node_index];
 
        match &node.inference_rule {
 
            IR::Noop =>
 
                unreachable!(),
 
            IR::MonoTemplate(_) =>
 
                self.progress_inference_rule_mono_template(ctx, node_index),
 
            IR::BiEqual(_) =>
 
                self.progress_inference_rule_bi_equal(ctx, node_index),
 
            IR::TriEqualArgs(_) =>
 
                self.progress_inference_rule_tri_equal_args(ctx, node_index),
 
            IR::TriEqualAll(_) =>
 
                self.progress_inference_rule_tri_equal_all(ctx, node_index),
 
            IR::Concatenate(_) =>
 
                self.progress_inference_rule_concatenate(ctx, node_index),
 
            IR::IndexingExpr(_) =>
 
                self.progress_inference_rule_indexing_expr(ctx, node_index),
 
            IR::SlicingExpr(_) =>
 
                self.progress_inference_rule_slicing_expr(ctx, node_index),
 
            IR::SelectStructField(_) =>
 
                self.progress_inference_rule_select_struct_field(ctx, node_index),
 
            IR::SelectTupleMember(_) =>
 
                self.progress_inference_rule_select_tuple_member(ctx, node_index),
 
            IR::LiteralStruct(_) =>
 
                self.progress_inference_rule_literal_struct(ctx, node_index),
 
            IR::LiteralEnum =>
 
                self.progress_inference_rule_literal_enum(ctx, node_index),
 
            IR::LiteralUnion(_) =>
 
                self.progress_inference_rule_literal_union(ctx, node_index),
 
            IR::LiteralArray(_) =>
 
                self.progress_inference_rule_literal_array(ctx, node_index),
 
            IR::LiteralTuple(_) =>
 
                self.progress_inference_rule_literal_tuple(ctx, node_index),
 
            IR::CastExpr(_) =>
 
                self.progress_inference_rule_cast_expr(ctx, node_index),
 
            IR::CallExpr(_) =>
 
                self.progress_inference_rule_call_expr(ctx, node_index),
 
            IR::VariableExpr(_) =>
 
                self.progress_inference_rule_variable_expr(ctx, node_index),
 
        }
 
    }
 

	
 
    fn progress_inference_rule_mono_template(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> {
 
        let node = &self.infer_nodes[node_index];
 
        let rule = *node.inference_rule.as_mono_template();
 

	
 
        let progress = self.progress_template(ctx, node_index, rule.application, rule.template)?;
 
        if progress { self.queue_node_parent(node_index); }
 

	
 
        return Ok(());
 
    }
 

	
 
    fn progress_inference_rule_bi_equal(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> {
 
        let node = &self.infer_nodes[node_index];
 
        let rule = node.inference_rule.as_bi_equal();
 
        let template = rule.template;
 
        let arg_index = rule.argument_index;
 

	
 
        let base_progress = self.progress_template(ctx, node_index, template.application, template.template)?;
 
        let (node_progress, arg_progress) = self.apply_equal2_constraint(ctx, node_index, node_index, 0, arg_index, 0)?;
 

	
 
        if base_progress || node_progress { self.queue_node_parent(node_index); }
 
        if arg_progress { self.queue_node(arg_index); }
 

	
 
        return Ok(())
 
    }
 

	
 
    fn progress_inference_rule_tri_equal_args(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> {
 
        let node = &self.infer_nodes[node_index];
 
        let rule = node.inference_rule.as_tri_equal_args();
 

	
 
        let result_template = rule.result_template;
 
        let argument_template = rule.argument_template;
 
        let arg1_index = rule.argument1_index;
 
        let arg2_index = rule.argument2_index;
 

	
 
        let self_template_progress = self.progress_template(ctx, node_index, result_template.application, result_template.template)?;
 
        let arg1_template_progress = self.progress_template(ctx, arg1_index, argument_template.application, argument_template.template)?;
 
        let (arg1_progress, arg2_progress) = self.apply_equal2_constraint(ctx, node_index, arg1_index, 0, arg2_index, 0)?;
 

	
 
        if self_template_progress { self.queue_node_parent(node_index); }
 
        if arg1_template_progress || arg1_progress { self.queue_node(arg1_index); }
 
        if arg2_progress { self.queue_node(arg2_index); }
 

	
 
        return Ok(());
 
    }
 

	
 
    fn progress_inference_rule_tri_equal_all(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> {
 
        let node = &self.infer_nodes[node_index];
 
        let rule = node.inference_rule.as_tri_equal_all();
 

	
 
        let template = rule.template;
 
        let arg1_index = rule.argument1_index;
 
        let arg2_index = rule.argument2_index;
 

	
 
        let template_progress = self.progress_template(ctx, node_index, template.application, template.template)?;
 
        let (node_progress, arg1_progress, arg2_progress) =
 
            self.apply_equal3_constraint(ctx, node_index, arg1_index, arg2_index, 0)?;
 

	
 
        if template_progress || node_progress { self.queue_node_parent(node_index); }
 
        if arg1_progress { self.queue_node(arg1_index); }
 
        if arg2_progress { self.queue_node(arg2_index); }
 

	
 
        return Ok(());
 
    }
 

	
 
    fn progress_inference_rule_concatenate(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> {
 
        let node = &self.infer_nodes[node_index];
 
        let rule = node.inference_rule.as_concatenate();
 
        let arg1_index = rule.argument1_index;
 
        let arg2_index = rule.argument2_index;
 

	
 
        // Two cases: one of the arguments is a string (then all must be), or
 
        // one of the arguments is an array (and all must be arrays).
 
        let (expr_is_str, expr_is_not_str) = self.type_is_certainly_or_certainly_not_string(node_index);
 
        let (arg1_is_str, arg1_is_not_str) = self.type_is_certainly_or_certainly_not_string(arg1_index);
 
        let (arg2_is_str, arg2_is_not_str) = self.type_is_certainly_or_certainly_not_string(arg2_index);
 

	
 
        let someone_is_str = expr_is_str || arg1_is_str || arg2_is_str;
 
        let someone_is_not_str = expr_is_not_str || arg1_is_not_str || arg2_is_not_str;
 
        // Note: this statement is an expression returning the progression bools
 
        let (node_progress, arg1_progress, arg2_progress) = if someone_is_str {
 
            // One of the arguments is a string, then all must be strings
 
            self.apply_equal3_constraint(ctx, node_index, arg1_index, arg2_index, 0)?
 
        } else {
 
            let progress_expr = if someone_is_not_str {
 
                // Output must be a normal array
 
                self.apply_template_constraint(ctx, node_index, &ARRAY_TEMPLATE)?
 
            } else {
 
                // Output may still be anything
 
                self.apply_template_constraint(ctx, node_index, &ARRAYLIKE_TEMPLATE)?
 
            };
 

	
 
            let progress_arg1 = self.apply_template_constraint(ctx, arg1_index, &ARRAYLIKE_TEMPLATE)?;
 
            let progress_arg2 = self.apply_template_constraint(ctx, arg2_index, &ARRAYLIKE_TEMPLATE)?;
 

	
 
            // If they're all arraylike, then we want the subtype to match
 
            let (subtype_expr, subtype_arg1, subtype_arg2) =
 
                self.apply_equal3_constraint(ctx, node_index, arg1_index, arg2_index, 1)?;
 

	
 
            (progress_expr || subtype_expr, progress_arg1 || subtype_arg1, progress_arg2 || subtype_arg2)
 
        };
 

	
 
        if node_progress { self.queue_node_parent(node_index); }
 
        if arg1_progress { self.queue_node(arg1_index); }
 
        if arg2_progress { self.queue_node(arg2_index); }
 

	
 
        return Ok(())
 
    }
 

	
 
    fn progress_inference_rule_indexing_expr(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> {
 
        let node = &self.infer_nodes[node_index];
 
        let rule = node.inference_rule.as_indexing_expr();
 
        let subject_index = rule.subject_index;
 
        let index_index = rule.index_index; // which one?
 

	
 
        // Subject is arraylike, index in integerlike
 
        let subject_template_progress = self.apply_template_constraint(ctx, subject_index, &ARRAYLIKE_TEMPLATE)?;
 
        let index_template_progress = self.apply_template_constraint(ctx, index_index, &INTEGERLIKE_TEMPLATE)?;
 

	
 
        // If subject is type `Array<T>`, then expr type is `T`
 
        let (node_progress, subject_progress) =
src/protocol/parser/type_table.rs
Show inline comments
 
@@ -280,386 +280,386 @@ pub struct UnionMonomorphVariant {
 
pub struct UnionMonomorphEmbedded {
 
    pub type_id: TypeId,
 
    concrete_type: ConcreteType,
 
    // Note that the meaning of the offset (and alignment) depend on whether or
 
    // not the variant lives on the stack/heap. If it lives on the stack then
 
    // they refer to the offset from the start of the union value (so the first
 
    // embedded type lives at a non-zero offset, because the union tag sits in
 
    // the front). If it lives on the heap then it refers to the offset from the
 
    // allocated memory region (so the first embedded type lives at a 0 offset).
 
    pub size: usize,
 
    pub alignment: usize,
 
    pub offset: usize,
 
}
 

	
 
/// Procedure (functions and components of all possible types) monomorph. Also
 
/// stores the expression type data from the typechecking/inferencing pass.
 
pub struct ProcedureMonomorph {
 
    pub monomorph_index: u32,
 
    pub builtin: bool,
 
}
 

	
 
/// Tuple monomorph. Again a kind of exception because one cannot define a named
 
/// tuple type containing explicit polymorphic variables. But again: we need to
 
/// store size/offset/alignment information, so we do it here.
 
pub struct TupleMonomorph {
 
    pub members: Vec<TupleMonomorphMember>
 
}
 

	
 
pub struct TupleMonomorphMember {
 
    pub type_id: TypeId,
 
    concrete_type: ConcreteType,
 
    pub size: usize,
 
    pub alignment: usize,
 
    pub offset: usize,
 
}
 

	
 
/// Generic unique type ID. Every monomorphed type and every non-polymorphic
 
/// type will have one of these associated with it.
 
#[derive(Debug, Clone, Copy, PartialEq)]
 
pub struct TypeId(i64);
 

	
 
impl TypeId {
 
    pub(crate) fn new_invalid() -> Self {
 
        return Self(-1);
 
    }
 
}
 

	
 
/// A monomorphed type (or non-polymorphic type's) memory layout and information
 
/// regarding associated types (like a struct's field type).
 
pub struct MonoType {
 
    pub type_id: TypeId,
 
    pub concrete_type: ConcreteType,
 
    pub size: usize,
 
    pub alignment: usize,
 
    pub(crate) variant: MonoTypeVariant
 
}
 

	
 
impl MonoType {
 
    #[inline]
 
    fn new_empty(type_id: TypeId, concrete_type: ConcreteType, variant: MonoTypeVariant) -> Self {
 
        return Self {
 
            type_id, concrete_type,
 
            size: 0,
 
            alignment: 0,
 
            variant,
 
        }
 
    }
 

	
 
    /// Little internal helper function as a reminder: if alignment is 0, then
 
    /// the size/alignment are not actually computed yet!
 
    #[inline]
 
    fn get_size_alignment(&self) -> Option<(usize, usize)> {
 
        if self.alignment == 0 {
 
            return None
 
        } else {
 
            return Some((self.size, self.alignment));
 
        }
 
    }
 
}
 

	
 
/// Special structure that acts like the lookup key for `ConcreteType` instances
 
/// that have already been added to the type table before.
 
#[derive(Clone)]
 
struct MonoSearchKey {
 
    // Uses bitflags to denote when parts between search keys should match and
 
    // whether they should be checked. Needs to have a system like this to
 
    // accommodate tuples.
 
    parts: Vec<(u8, ConcreteTypePart)>,
 
    change_bit: u8,
 
}
 

	
 
impl MonoSearchKey {
 
    const KEY_IN_USE: u8 = 0x01;
 
    const KEY_CHANGE_BIT: u8 = 0x02;
 

	
 
    fn with_capacity(capacity: usize) -> Self {
 
        return MonoSearchKey{
 
            parts: Vec::with_capacity(capacity),
 
            change_bit: 0,
 
        };
 
    }
 

	
 
    /// Sets the search key based on a single concrete type and its polymorphic
 
    /// variables.
 
    fn set(&mut self, concrete_type_parts: &[ConcreteTypePart], poly_var_in_use: &[PolymorphicVariable]) {
 
        self.set_top_type(concrete_type_parts[0]);
 

	
 
        let mut poly_var_index = 0;
 
        for subtype in ConcreteTypeIter::new(concrete_type_parts, 0) {
 
            let in_use = poly_var_in_use[poly_var_index].is_in_use;
 
            poly_var_index += 1;
 
            self.push_subtype(subtype, in_use);
 
        }
 

	
 
        debug_assert_eq!(poly_var_index, poly_var_in_use.len());
 
    }
 

	
 
    /// Starts setting the search key based on an initial top-level type,
 
    /// programmer must call `push_subtype` the appropriate number of times
 
    /// after calling this function
 
    fn set_top_type(&mut self, type_part: ConcreteTypePart) {
 
        self.parts.clear();
 
        self.parts.push((Self::KEY_IN_USE, type_part));
 
        self.change_bit = Self::KEY_CHANGE_BIT;
 
    }
 

	
 
    fn push_subtype(&mut self, concrete_type: &[ConcreteTypePart], in_use: bool) {
 
        let flag = self.change_bit | (if in_use { Self::KEY_IN_USE } else { 0 });
 

	
 
        for part in concrete_type {
 
            self.parts.push((flag, *part));
 
        }
 
        self.change_bit ^= Self::KEY_CHANGE_BIT;
 
    }
 

	
 
    fn push_subtree(&mut self, concrete_type: &[ConcreteTypePart], poly_var_in_use: &[PolymorphicVariable]) {
 
        self.parts.push((self.change_bit | Self::KEY_IN_USE, concrete_type[0]));
 
        self.change_bit ^= Self::KEY_CHANGE_BIT;
 

	
 
        let mut poly_var_index = 0;
 
        for subtype in ConcreteTypeIter::new(concrete_type, 0) {
 
            let in_use = poly_var_in_use[poly_var_index].is_in_use;
 
            poly_var_index += 1;
 
            self.push_subtype(subtype, in_use);
 
        }
 

	
 
        debug_assert_eq!(poly_var_index, poly_var_in_use.len());
 
    }
 

	
 
    // Utilities for hashing and comparison
 
    fn find_end_index(&self, start_index: usize) -> usize {
 
        // Check if we're already at the end
 
        let mut index = start_index;
 
        if index >= self.parts.len() {
 
            return index;
 
        }
 

	
 
        // Iterate until bit flips, or until at end
 
        let expected_bit = self.parts[index].0 & Self::KEY_CHANGE_BIT;
 

	
 
        index += 1;
 
        while index < self.parts.len() {
 
            let current_bit = self.parts[index].0 & Self::KEY_CHANGE_BIT;
 
            if current_bit != expected_bit {
 
                return index;
 
            }
 

	
 
            index += 1;
 
        }
 

	
 
        return index;
 
    }
 
}
 

	
 
impl Hash for MonoSearchKey {
 
    fn hash<H: Hasher>(&self, state: &mut H) {
 
        for index in 0..self.parts.len() {
 
            let (_flags, part) = self.parts[index];
 
            // if flags & Self::KEY_IN_USE != 0 { @Deduplication
 
            part.hash(state);
 
            // }
 
        }
 
    }
 
}
 

	
 
impl PartialEq for MonoSearchKey {
 
    fn eq(&self, other: &Self) -> bool {
 
        let mut self_index = 0;
 
        let mut other_index = 0;
 

	
 
        while self_index < self.parts.len() && other_index < other.parts.len() {
 
            // Retrieve part and flags
 
            let (self_bits, _) = self.parts[self_index];
 
            let (other_bits, _) = other.parts[other_index];
 
            let (_self_bits, _) = self.parts[self_index];
 
            let (_other_bits, _) = other.parts[other_index];
 
            let self_in_use = true; // (self_bits & Self::KEY_IN_USE) != 0; @Deduplication
 
            let other_in_use = true; // (other_bits & Self::KEY_IN_USE) != 0; @Deduplication
 

	
 
            // Determine ending indices
 
            let self_end_index = self.find_end_index(self_index);
 
            let other_end_index = other.find_end_index(other_index);
 

	
 
            if self_in_use == other_in_use {
 
                if self_in_use {
 
                    // Both are in use, so both parts should be equal
 
                    let delta_self = self_end_index - self_index;
 
                    let delta_other = other_end_index - other_index;
 
                    if delta_self != delta_other {
 
                        // Both in use, but not of equal length, so the types
 
                        // cannot match
 
                        return false;
 
                    }
 

	
 
                    for _ in 0..delta_self {
 
                        let (_, self_part) = self.parts[self_index];
 
                        let (_, other_part) = other.parts[other_index];
 

	
 
                        if self_part != other_part {
 
                            return false;
 
                        }
 

	
 
                        self_index += 1;
 
                        other_index += 1;
 
                    }
 
                } else {
 
                    // Both not in use, so skip associated parts
 
                    self_index = self_end_index;
 
                    other_index = other_end_index;
 
                }
 
            } else {
 
                // No agreement on importance of parts. This is practically
 
                // impossible
 
                unreachable!();
 
            }
 
        }
 

	
 
        // Everything matched, so if we're at the end of both arrays then we're
 
        // certain that the two keys are equal.
 
        return self_index == self.parts.len() && other_index == other.parts.len();
 
    }
 
}
 

	
 
impl Eq for MonoSearchKey{}
 

	
 
//------------------------------------------------------------------------------
 
// Type table
 
//------------------------------------------------------------------------------
 

	
 
const POLY_VARS_IN_USE: [PolymorphicVariable; 1] = [PolymorphicVariable{ identifier: Identifier::new_empty(InputSpan::new()), is_in_use: true }];
 

	
 
// Programmer note: keep this struct free of dynamically allocated memory
 
#[derive(Clone)]
 
struct TypeLoopBreadcrumb {
 
    type_id: TypeId,
 
    next_member: u32,
 
    next_embedded: u32, // for unions, the index into the variant's embedded types
 
}
 

	
 
// Programmer note: keep this struct free of dynamically allocated memory
 
#[derive(Clone)]
 
struct MemoryBreadcrumb {
 
    type_id: TypeId,
 
    next_member: u32,
 
    next_embedded: u32,
 
    first_size_alignment_idx: u32,
 
}
 

	
 
#[derive(Debug, PartialEq, Eq)]
 
enum TypeLoopResult {
 
    TypeExists,
 
    PushBreadcrumb(DefinitionId, ConcreteType),
 
    TypeLoop(usize), // index into vec of breadcrumbs at which the type matched
 
}
 

	
 
enum MemoryLayoutResult {
 
    TypeExists(usize, usize), // (size, alignment)
 
    PushBreadcrumb(MemoryBreadcrumb),
 
}
 

	
 
// TODO: @Optimize, initial memory-unoptimized implementation
 
struct TypeLoopEntry {
 
    type_id: TypeId,
 
    is_union: bool,
 
}
 

	
 
struct TypeLoop {
 
    members: Vec<TypeLoopEntry>,
 
}
 

	
 
type DefinitionMap = HashMap<DefinitionId, DefinedType>;
 
type MonoTypeMap = HashMap<MonoSearchKey, TypeId>;
 
type MonoTypeArray = Vec<MonoType>;
 

	
 
pub struct TypeTable {
 
    // Lookup from AST DefinitionId to a defined type. Also lookups for
 
    // concrete type to monomorphs
 
    pub(crate) definition_lookup: DefinitionMap,
 
    mono_type_lookup: MonoTypeMap,
 
    pub(crate) mono_types: MonoTypeArray,
 
    mono_search_key: MonoSearchKey,
 
    // Breadcrumbs left behind while trying to find type loops. Also used to
 
    // determine sizes of types when all type loops are detected.
 
    type_loop_breadcrumbs: Vec<TypeLoopBreadcrumb>,
 
    type_loops: Vec<TypeLoop>,
 
    // Stores all encountered types during type loop detection. Used afterwards
 
    // to iterate over all types in order to compute size/alignment.
 
    encountered_types: Vec<TypeLoopEntry>,
 
    // Breadcrumbs and temporary storage during memory layout computation.
 
    memory_layout_breadcrumbs: Vec<MemoryBreadcrumb>,
 
    size_alignment_stack: Vec<(usize, usize)>,
 
}
 

	
 
impl TypeTable {
 
    /// Construct a new type table without any resolved types.
 
    pub(crate) fn new() -> Self {
 
        Self{ 
 
            definition_lookup: HashMap::with_capacity(128),
 
            mono_type_lookup: HashMap::with_capacity(128),
 
            mono_types: Vec::with_capacity(128),
 
            mono_search_key: MonoSearchKey::with_capacity(32),
 
            type_loop_breadcrumbs: Vec::with_capacity(32),
 
            type_loops: Vec::with_capacity(8),
 
            encountered_types: Vec::with_capacity(32),
 
            memory_layout_breadcrumbs: Vec::with_capacity(32),
 
            size_alignment_stack: Vec::with_capacity(64),
 
        }
 
    }
 

	
 
    /// Iterates over all defined types (polymorphic and non-polymorphic) and
 
    /// add their types in two passes. In the first pass we will just add the
 
    /// base types (we will not consider monomorphs, and we will not compute
 
    /// byte sizes). In the second pass we will compute byte sizes of
 
    /// non-polymorphic types, and potentially the monomorphs that are embedded
 
    /// in those types.
 
    pub(crate) fn build_base_types(&mut self, modules: &mut [Module], ctx: &mut PassCtx) -> Result<(), ParseError> {
 
        // Make sure we're allowed to cast root_id to index into ctx.modules
 
        debug_assert!(modules.iter().all(|m| m.phase >= ModuleCompilationPhase::DefinitionsParsed));
 
        debug_assert!(self.definition_lookup.is_empty());
 

	
 
        dbg_code!({
 
            for (index, module) in modules.iter().enumerate() {
 
                debug_assert_eq!(index, module.root_id.index as usize);
 
            }
 
        });
 

	
 
        // Use context to guess hashmap size of the base types
 
        let reserve_size = ctx.heap.definitions.len();
 
        self.definition_lookup.reserve(reserve_size);
 

	
 
        // Resolve all base types
 
        for definition_idx in 0..ctx.heap.definitions.len() {
 
            let definition_id = ctx.heap.definitions.get_id(definition_idx);
 
            let definition = &ctx.heap[definition_id];
 

	
 
            match definition {
 
                Definition::Enum(_) => self.build_base_enum_definition(modules, ctx, definition_id)?,
 
                Definition::Union(_) => self.build_base_union_definition(modules, ctx, definition_id)?,
 
                Definition::Struct(_) => self.build_base_struct_definition(modules, ctx, definition_id)?,
 
                Definition::Procedure(_) => self.build_base_procedure_definition(modules, ctx, definition_id)?,
 
            }
 
        }
 

	
 
        debug_assert_eq!(self.definition_lookup.len(), reserve_size, "mismatch in reserved size of type table");
 
        for module in modules.iter_mut() {
 
            module.phase = ModuleCompilationPhase::TypesAddedToTable;
 
        }
 

	
 
        // Go through all types again, lay out all types that are not
 
        // polymorphic. This might cause us to lay out monomorphized polymorphs
 
        // if these were member types of non-polymorphic types.
 
        for definition_idx in 0..ctx.heap.definitions.len() {
 
            let definition_id = ctx.heap.definitions.get_id(definition_idx);
 
            let poly_type = self.definition_lookup.get(&definition_id).unwrap();
 

	
 
            if !poly_type.definition.is_data_type() || !poly_type.poly_vars.is_empty() {
 
                continue;
 
            }
 

	
 
            // If here then the type is a data type without polymorphic
 
            // variables, but we might have instantiated it already, so:
 
            let concrete_parts = [ConcreteTypePart::Instance(definition_id, 0)];
 
            self.mono_search_key.set(&concrete_parts, &[]);
 
            let type_id = self.mono_type_lookup.get(&self.mono_search_key);
 
            if type_id.is_none() {
 
                self.detect_and_resolve_type_loops_for(
 
                    modules, ctx.heap, ctx.arch,
 
                    ConcreteType{
 
@@ -1334,393 +1334,395 @@ impl TypeTable {
 
        // loop and that union ended up having variants that are not part of
 
        // a type loop.
 
        fn type_loop_source_span_and_message<'a>(
 
            modules: &'a [Module], heap: &Heap, mono_types: &MonoTypeArray,
 
            definition_id: DefinitionId, mono_type_id: TypeId, index_in_loop: usize
 
        ) -> (&'a InputSource, InputSpan, String) {
 
            // Note: because we will discover the type loop the *first* time we
 
            // instantiate a monomorph with the provided polymorphic arguments
 
            // (not all arguments are actually used in the type). We don't have
 
            // to care about a second instantiation where certain unused
 
            // polymorphic arguments are different.
 
            let mono_type = &mono_types[mono_type_id.0 as usize];
 
            let type_name = mono_type.concrete_type.display_name(heap);
 

	
 
            let message = if index_in_loop == 0 {
 
                format!(
 
                    "encountered an infinitely large type for '{}' (which can be fixed by \
 
                    introducing a union type that has a variant whose embedded types are \
 
                    not part of a type loop, or do not have embedded types)",
 
                    type_name
 
                )
 
            } else if index_in_loop == 1 {
 
                format!("because it depends on the type '{}'", type_name)
 
            } else {
 
                format!("which depends on the type '{}'", type_name)
 
            };
 

	
 
            let ast_definition = &heap[definition_id];
 
            let ast_root_id = ast_definition.defined_in();
 

	
 
            return (
 
                &modules[ast_root_id.index as usize].source,
 
                ast_definition.identifier().span,
 
                message
 
            );
 
        }
 

	
 
        fn construct_type_loop_error(mono_types: &MonoTypeArray, type_loop: &TypeLoop, modules: &[Module], heap: &Heap) -> ParseError {
 
            // Seek first entry to produce parse error. Then continue builder
 
            // pattern. This is the error case so efficiency can go home.
 
            let mut parse_error = None;
 
            let mut next_member_index = 0;
 
            while next_member_index < type_loop.members.len() {
 
                let first_entry = &type_loop.members[next_member_index];
 
                next_member_index += 1;
 

	
 
                // Retrieve definition of first type in loop
 
                let first_mono_type = &mono_types[first_entry.type_id.0 as usize];
 
                let first_definition_id = get_concrete_type_definition(&first_mono_type.concrete_type.parts);
 
                if first_definition_id.is_none() {
 
                    continue;
 
                }
 
                let first_definition_id = first_definition_id.unwrap();
 

	
 
                // Produce error message for first type in loop
 
                let (first_module, first_span, first_message) = type_loop_source_span_and_message(
 
                    modules, heap, mono_types, first_definition_id, first_entry.type_id, 0
 
                );
 
                parse_error = Some(ParseError::new_error_at_span(first_module, first_span, first_message));
 
                break;
 
            }
 

	
 
            let mut parse_error = parse_error.unwrap(); // Loop above cannot have failed, because we must have a type loop, type loops cannot contain only unnamed types
 

	
 
            let mut error_counter = 1;
 
            for member_idx in next_member_index..type_loop.members.len() {
 
                let entry = &type_loop.members[member_idx];
 
                let mono_type = &mono_types[entry.type_id.0 as usize];
 
                let definition_id = get_concrete_type_definition(&mono_type.concrete_type.parts);
 
                if definition_id.is_none() {
 
                    continue;
 
                }
 
                let definition_id = definition_id.unwrap();
 

	
 
                let (module, span, message) = type_loop_source_span_and_message(
 
                    modules, heap, mono_types, definition_id, entry.type_id, error_counter
 
                );
 
                parse_error = parse_error.with_info_at_span(module, span, message);
 
                error_counter += 1;
 
            }
 

	
 
            parse_error
 
        }
 

	
 
        for type_loop in &self.type_loops {
 
            let mut can_be_broken = false;
 
            debug_assert!(!type_loop.members.is_empty());
 

	
 
            for entry in &type_loop.members {
 
                if entry.is_union {
 
                    let mono_type = self.mono_types[entry.type_id.0 as usize].variant.as_union();
 
                    debug_assert!(!mono_type.variants.is_empty()); // otherwise it couldn't be part of the type loop
 
                    let has_stack_variant = mono_type.variants.iter().any(|variant| !variant.lives_on_heap);
 
                    if has_stack_variant {
 
                        can_be_broken = true;
 
                        break;
 
                    }
 
                }
 
            }
 

	
 
            if !can_be_broken {
 
                // Construct a type loop error
 
                return Err(construct_type_loop_error(&self.mono_types, type_loop, modules, heap));
 
            }
 
        }
 

	
 
        // If here, then all type loops have been resolved and we can lay out
 
        // all of the members
 
        self.type_loops.clear();
 

	
 
        return Ok(());
 
    }
 

	
 
    /// Checks if the specified type needs to be resolved (i.e. we need to push
 
    /// a breadcrumb), is already resolved (i.e. we can continue with the next
 
    /// member of the currently considered type) or is in the process of being
 
    /// resolved (i.e. we're in a type loop). Because of borrowing rules we
 
    /// don't do any modifications of internal types here. Hence: if we
 
    /// return `PushBreadcrumb` then call `handle_new_breadcrumb_for_type_loops`
 
    /// to take care of storing the appropriate types.
 
    fn check_member_for_type_loops(
 
        breadcrumbs: &[TypeLoopBreadcrumb], definition_map: &DefinitionMap, mono_type_map: &MonoTypeMap,
 
        mono_key: &mut MonoSearchKey, concrete_type: &ConcreteType
 
    ) -> TypeLoopResult {
 
        use ConcreteTypePart as CTP;
 

	
 
        // Depending on the type, lookup if the type has already been visited
 
        // (i.e. either already has its memory layed out, or is part of a type
 
        // loop because we've already visited the type)
 
        debug_assert!(!concrete_type.parts.is_empty());
 
        let definition_id = if let ConcreteTypePart::Instance(definition_id, _) = concrete_type.parts[0] {
 
            definition_id
 
        } else {
 
            DefinitionId::new_invalid()
 
        };
 

	
 
        Self::set_search_key_to_type(mono_key, definition_map, &concrete_type.parts);
 
        if let Some(type_id) = mono_type_map.get(mono_key).copied() {
 
            for (breadcrumb_idx, breadcrumb) in breadcrumbs.iter().enumerate() {
 
                if breadcrumb.type_id == type_id {
 
                    return TypeLoopResult::TypeLoop(breadcrumb_idx);
 
                }
 
            }
 

	
 
            return TypeLoopResult::TypeExists;
 
        }
 

	
 
        // Type is not yet known, so we need to insert it into the lookup and
 
        // push a new breadcrumb.
 
        return TypeLoopResult::PushBreadcrumb(definition_id, concrete_type.clone());
 
    }
 

	
 
    /// Handles the `PushBreadcrumb` result for a `check_member_for_type_loops`
 
    /// call. Will preallocate entries in the monomorphed type storage (with
 
    /// all memory properties zeroed).
 
    fn handle_new_breadcrumb_for_type_loops(&mut self, arch: &TargetArch, definition_id: DefinitionId, concrete_type: ConcreteType) {
 
        use DefinedTypeVariant as DTV;
 
        use ConcreteTypePart as CTP;
 

	
 
        let mut is_union = false;
 

	
 
        let type_id = match &concrete_type.parts[0] {
 
            // Builtin types
 
            CTP::Void | CTP::Message | CTP::Bool |
 
            CTP::UInt8 | CTP::UInt16 | CTP::UInt32 | CTP::UInt64 |
 
            CTP::SInt8 | CTP::SInt16 | CTP::SInt32 | CTP::SInt64 |
 
            CTP::Character | CTP::String |
 
            CTP::Array | CTP::Slice | CTP::Input | CTP::Output | CTP::Pointer => {
 
                // Insert the entry for the builtin type, we should be able to
 
                // immediately "steal" the size from the preinserted builtins.
 
                let base_type_id = match &concrete_type.parts[0] {
 
                    CTP::Void => arch.void_type_id,
 
                    CTP::Message => arch.message_type_id,
 
                    CTP::Bool => arch.bool_type_id,
 
                    CTP::UInt8 => arch.uint8_type_id,
 
                    CTP::UInt16 => arch.uint16_type_id,
 
                    CTP::UInt32 => arch.uint32_type_id,
 
                    CTP::UInt64 => arch.uint64_type_id,
 
                    CTP::SInt8 => arch.sint8_type_id,
 
                    CTP::SInt16 => arch.sint16_type_id,
 
                    CTP::SInt32 => arch.sint32_type_id,
 
                    CTP::SInt64 => arch.sint64_type_id,
 
                    CTP::Character => arch.char_type_id,
 
                    CTP::String => arch.string_type_id,
 
                    CTP::Array => arch.array_type_id,
 
                    CTP::Slice => arch.slice_type_id,
 
                    CTP::Input => arch.input_type_id,
 
                    CTP::Output => arch.output_type_id,
 
                    CTP::Pointer => arch.pointer_type_id,
 
                    _ => unreachable!(),
 
                };
 
                let base_type = &self.mono_types[base_type_id.0 as usize];
 
                let base_type_size = base_type.size;
 
                let base_type_alignment = base_type.alignment;
 

	
 
                let type_id = TypeId(self.mono_types.len() as i64);
 
                Self::set_search_key_to_type(&mut self.mono_search_key, &self.definition_lookup, &concrete_type.parts);
 
                self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id);
 
                self.mono_types.push(MonoType{
 
                    type_id,
 
                    concrete_type,
 
                    size: base_type.size,
 
                    alignment: base_type.alignment,
 
                    size: base_type_size,
 
                    alignment: base_type_alignment,
 
                    variant: MonoTypeVariant::Builtin
 
                });
 

	
 
                type_id
 
            },
 
            // User-defined types
 
            CTP::Tuple(num_embedded) => {
 
                debug_assert!(definition_id.is_invalid()); // because tuples do not have an associated `DefinitionId`
 
                let mut members = Vec::with_capacity(*num_embedded as usize);
 
                for section in ConcreteTypeIter::new(&concrete_type.parts, 0) {
 
                    members.push(TupleMonomorphMember{
 
                        type_id: TypeId::new_invalid(),
 
                        concrete_type: ConcreteType{ parts: Vec::from(section) },
 
                        size: 0,
 
                        alignment: 0,
 
                        offset: 0
 
                    });
 
                }
 

	
 
                let type_id = TypeId(self.mono_types.len() as i64);
 
                Self::set_search_key_to_tuple(&mut self.mono_search_key, &self.definition_lookup, &concrete_type.parts);
 
                self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id);
 
                self.mono_types.push(MonoType::new_empty(type_id, concrete_type, MonoTypeVariant::Tuple(TupleMonomorph{ members })));
 

	
 
                type_id
 
            },
 
            CTP::Instance(_check_definition_id, _) => {
 
                debug_assert_eq!(definition_id, *_check_definition_id); // because this is how `definition_id` was determined
 

	
 
                Self::set_search_key_to_type(&mut self.mono_search_key, &self.definition_lookup, &concrete_type.parts);
 
                let base_type = self.definition_lookup.get(&definition_id).unwrap();
 
                let type_id = match &base_type.definition {
 
                    DTV::Enum(definition) => {
 
                        // The enum is a bit exceptional in that when we insert
 
                        // it we we will immediately set its size/alignment:
 
                        // there is nothing to compute here.
 
                        debug_assert!(definition.size != 0 && definition.alignment != 0);
 
                        let type_id = TypeId(self.mono_types.len() as i64);
 
                        self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id);
 
                        self.mono_types.push(MonoType::new_empty(type_id, concrete_type, MonoTypeVariant::Enum));
 

	
 
                        let mono_type = &mut self.mono_types[type_id.0 as usize];
 
                        mono_type.size = definition.size;
 
                        mono_type.alignment = definition.alignment;
 

	
 
                        type_id
 
                    },
 
                    DTV::Union(definition) => {
 
                        // Create all the variants with their concrete types
 
                        let mut mono_variants = Vec::with_capacity(definition.variants.len());
 
                        for poly_variant in &definition.variants {
 
                            let mut mono_embedded = Vec::with_capacity(poly_variant.embedded.len());
 
                            for poly_embedded in &poly_variant.embedded {
 
                                let mono_concrete = Self::construct_concrete_type(poly_embedded, &concrete_type);
 
                                mono_embedded.push(UnionMonomorphEmbedded{
 
                                    type_id: TypeId::new_invalid(),
 
                                    concrete_type: mono_concrete,
 
                                    size: 0,
 
                                    alignment: 0,
 
                                    offset: 0
 
                                });
 
                            }
 

	
 
                            mono_variants.push(UnionMonomorphVariant{
 
                                lives_on_heap: false,
 
                                embedded: mono_embedded,
 
                            })
 
                        }
 

	
 
                        let type_id = TypeId(self.mono_types.len() as i64);
 
                        let tag_size = definition.tag_size;
 
                        Self::set_search_key_to_type(&mut self.mono_search_key, &self.definition_lookup, &concrete_type.parts);
 
                        self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id);
 
                        self.mono_types.push(MonoType::new_empty(type_id, concrete_type, MonoTypeVariant::Union(UnionMonomorph{
 
                            variants: mono_variants,
 
                            tag_size,
 
                            heap_size: 0,
 
                            heap_alignment: 0,
 
                        })));
 

	
 
                        is_union = true;
 
                        type_id
 
                    },
 
                    DTV::Struct(definition) => {
 
                        // Create fields
 
                        let mut mono_fields = Vec::with_capacity(definition.fields.len());
 
                        for poly_field in &definition.fields {
 
                            let mono_concrete = Self::construct_concrete_type(&poly_field.parser_type, &concrete_type);
 
                            mono_fields.push(StructMonomorphField{
 
                                type_id: TypeId::new_invalid(),
 
                                concrete_type: mono_concrete,
 
                                size: 0,
 
                                alignment: 0,
 
                                offset: 0
 
                            })
 
                        }
 

	
 
                        let type_id = TypeId(self.mono_types.len() as i64);
 
                        Self::set_search_key_to_type(&mut self.mono_search_key, &self.definition_lookup, &concrete_type.parts);
 
                        self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id);
 
                        self.mono_types.push(MonoType::new_empty(type_id, concrete_type, MonoTypeVariant::Struct(StructMonomorph{
 
                            fields: mono_fields,
 
                        })));
 

	
 
                        type_id
 
                    },
 
                    DTV::Procedure(_) => {
 
                        unreachable!("pushing type resolving breadcrumb for procedure type")
 
                    },
 
                };
 

	
 
                type_id
 
            },
 
            CTP::Function(_, _) | CTP::Component(_, _) => todo!("function pointers"),
 
        };
 

	
 
        self.encountered_types.push(TypeLoopEntry{ type_id, is_union });
 
        self.type_loop_breadcrumbs.push(TypeLoopBreadcrumb{
 
            type_id,
 
            next_member: 0,
 
            next_embedded: 0,
 
        });
 
    }
 

	
 
    /// Constructs a concrete type out of a parser type for a struct field or
 
    /// union embedded type. It will do this by looking up the polymorphic
 
    /// variables in the supplied concrete type. The assumption is that the
 
    /// polymorphic variable's indices correspond to the subtrees in the
 
    /// concrete type.
 
    fn construct_concrete_type(member_type: &ParserType, container_type: &ConcreteType) -> ConcreteType {
 
        use ParserTypeVariant as PTV;
 
        use ConcreteTypePart as CTP;
 

	
 
        // TODO: Combine with code in pass_typing.rs
 
        fn parser_to_concrete_part(part: &ParserTypeVariant) -> Option<ConcreteTypePart> {
 
            match part {
 
                PTV::Void      => Some(CTP::Void),
 
                PTV::Message   => Some(CTP::Message),
 
                PTV::Bool      => Some(CTP::Bool),
 
                PTV::UInt8     => Some(CTP::UInt8),
 
                PTV::UInt16    => Some(CTP::UInt16),
 
                PTV::UInt32    => Some(CTP::UInt32),
 
                PTV::UInt64    => Some(CTP::UInt64),
 
                PTV::SInt8     => Some(CTP::SInt8),
 
                PTV::SInt16    => Some(CTP::SInt16),
 
                PTV::SInt32    => Some(CTP::SInt32),
 
                PTV::SInt64    => Some(CTP::SInt64),
 
                PTV::Character => Some(CTP::Character),
 
                PTV::String    => Some(CTP::String),
 
                PTV::Array     => Some(CTP::Array),
 
                PTV::Input     => Some(CTP::Input),
 
                PTV::Output    => Some(CTP::Output),
 
                PTV::Tuple(num) => Some(CTP::Tuple(*num)),
 
                PTV::Definition(definition_id, num) => Some(CTP::Instance(*definition_id, *num)),
 
                _              => None
 
            }
 
        }
 

	
 
        let mut parts = Vec::with_capacity(member_type.elements.len()); // usually a correct estimation, might not be
 
        for member_part in &member_type.elements {
 
            // Check if we have a regular builtin type
 
            if let Some(part) = parser_to_concrete_part(&member_part.variant) {
 
                parts.push(part);
 
                continue;
 
            }
 

	
 
            // Not builtin, but if all code is working correctly, we only care
 
            // about the polymorphic argument at this point.
 
            if let PTV::PolymorphicArgument(_container_definition_id, poly_arg_idx) = member_part.variant {
 
                debug_assert_eq!(_container_definition_id, get_concrete_type_definition(&container_type.parts).unwrap());
 

	
 
                let mut container_iter = container_type.embedded_iter(0);
 
                for _ in 0..poly_arg_idx {
 
                    container_iter.next();
 
                }
 

	
 
                let poly_section = container_iter.next().unwrap();
 
                parts.extend(poly_section);
 

	
 
                continue;
 
            }
 

	
 
            unreachable!("unexpected type part {:?} from {:?}", member_part, member_type);
 
        }
 

	
 
        return ConcreteType{ parts };
 
    }
 

	
 
    //--------------------------------------------------------------------------
 
    // Determining memory layout for types
 
    //--------------------------------------------------------------------------
 

	
src/protocol/tests/utils.rs
Show inline comments
 
@@ -1083,196 +1083,194 @@ fn seek_def_in_modules<'a>(heap: &Heap, modules: &'a [Module], def_id: Definitio
 
    }
 

	
 
    None
 
}
 

	
 
fn seek_stmt<F: Fn(&Statement) -> bool>(heap: &Heap, start: StatementId, f: &F) -> Option<StatementId> {
 
    let stmt = &heap[start];
 
    if f(stmt) { return Some(start); }
 

	
 
    // This statement wasn't it, try to recurse
 
    let matched = match stmt {
 
        Statement::Block(block) => {
 
            for sub_id in &block.statements {
 
                if let Some(id) = seek_stmt(heap, *sub_id, f) {
 
                    return Some(id);
 
                }
 
            }
 

	
 
            None
 
        },
 
        Statement::Labeled(stmt) => seek_stmt(heap, stmt.body, f),
 
        Statement::If(stmt) => {
 
            if let Some(id) = seek_stmt(heap, stmt.true_case.body, f) {
 
                return Some(id);
 
            } else if let Some(false_body) = stmt.false_case {
 
                if let Some(id) = seek_stmt(heap, false_body.body, f) {
 
                    return Some(id);
 
                }
 
            }
 
            None
 
        },
 
        Statement::While(stmt) => seek_stmt(heap, stmt.body, f),
 
        Statement::Synchronous(stmt) => seek_stmt(heap, stmt.body, f),
 
        _ => None
 
    };
 

	
 
    matched
 
}
 

	
 
fn seek_scope<F: Fn(&Scope) -> bool>(heap: &Heap, start: ScopeId, f: &F) -> Option<ScopeId> {
 
    let scope = &heap[start];
 
    if f(scope) { return Some(start); }
 

	
 
    for child_scope_id in scope.nested.iter().copied() {
 
        if let Some(result) = seek_scope(heap, child_scope_id, f) {
 
            return Some(result);
 
        }
 
    }
 

	
 
    return None;
 
}
 

	
 
fn seek_expr_in_expr<F: Fn(&Expression) -> bool>(heap: &Heap, start: ExpressionId, f: &F) -> Option<ExpressionId> {
 
    let expr = &heap[start];
 
    if f(expr) { return Some(start); }
 

	
 
    match expr {
 
        Expression::Assignment(expr) => {
 
            None
 
            .or_else(|| seek_expr_in_expr(heap, expr.left, f))
 
            .or_else(|| seek_expr_in_expr(heap, expr.right, f))
 
        },
 
        Expression::Binding(expr) => {
 
            None
 
            .or_else(|| seek_expr_in_expr(heap, expr.bound_to, f))
 
            .or_else(|| seek_expr_in_expr(heap, expr.bound_from, f))
 
        }
 
        Expression::Conditional(expr) => {
 
            None
 
            .or_else(|| seek_expr_in_expr(heap, expr.test, f))
 
            .or_else(|| seek_expr_in_expr(heap, expr.true_expression, f))
 
            .or_else(|| seek_expr_in_expr(heap, expr.false_expression, f))
 
        },
 
        Expression::Binary(expr) => {
 
            None
 
            .or_else(|| seek_expr_in_expr(heap, expr.left, f))
 
            .or_else(|| seek_expr_in_expr(heap, expr.right, f))
 
        },
 
        Expression::Unary(expr) => {
 
            seek_expr_in_expr(heap, expr.expression, f)
 
        },
 
        Expression::Indexing(expr) => {
 
            None
 
            .or_else(|| seek_expr_in_expr(heap, expr.subject, f))
 
            .or_else(|| seek_expr_in_expr(heap, expr.index, f))
 
        },
 
        Expression::Slicing(expr) => {
 
            None
 
            .or_else(|| seek_expr_in_expr(heap, expr.subject, f))
 
            .or_else(|| seek_expr_in_expr(heap, expr.from_index, f))
 
            .or_else(|| seek_expr_in_expr(heap, expr.to_index, f))
 
        },
 
        Expression::Select(expr) => {
 
            seek_expr_in_expr(heap, expr.subject, f)
 
        },
 
        Expression::Literal(expr) => {
 
            if let Literal::Struct(lit) = &expr.value {
 
                for field in &lit.fields {
 
                    if let Some(id) = seek_expr_in_expr(heap, field.value, f) {
 
                        return Some(id)
 
                    }
 
                }
 
            } else if let Literal::Array(elements) = &expr.value {
 
                for element in elements {
 
                    if let Some(id) = seek_expr_in_expr(heap, *element, f) {
 
                        return Some(id)
 
                    }
 
                }
 
            }
 
            None
 
        },
 
        Expression::Cast(expr) => {
 
            seek_expr_in_expr(heap, expr.subject, f)
 
        }
 
        Expression::Call(expr) => {
 
            for arg in &expr.arguments {
 
                if let Some(id) = seek_expr_in_expr(heap, *arg, f) {
 
                    return Some(id)
 
                }
 
            }
 
            None
 
        },
 
        Expression::Variable(_expr) => {
 
            None
 
        }
 
    }
 
}
 

	
 
fn seek_expr_in_stmt<F: Fn(&Expression) -> bool>(heap: &Heap, start: StatementId, f: &F) -> Option<ExpressionId> {
 
    let stmt = &heap[start];
 

	
 
    match stmt {
 
        Statement::Local(stmt) => {
 
            match stmt {
 
                LocalStatement::Memory(stmt) => seek_expr_in_expr(heap, stmt.initial_expr.upcast(), f),
 
                LocalStatement::Channel(_) => None
 
            }
 
        }
 
        Statement::Block(stmt) => {
 
            for stmt_id in &stmt.statements {
 
                if let Some(id) = seek_expr_in_stmt(heap, *stmt_id, f) {
 
                    return Some(id)
 
                }
 
            }
 
            None
 
        },
 
        Statement::Labeled(stmt) => {
 
            seek_expr_in_stmt(heap, stmt.body, f)
 
        },
 
        Statement::If(stmt) => {
 
            None
 
            .or_else(|| seek_expr_in_expr(heap, stmt.test, f))
 
            .or_else(|| seek_expr_in_stmt(heap, stmt.true_case.body, f))
 
            .or_else(|| if let Some(false_body) = stmt.false_case {
 
                seek_expr_in_stmt(heap, false_body.body, f)
 
            } else {
 
                None
 
            })
 
        },
 
        Statement::While(stmt) => {
 
            None
 
            .or_else(|| seek_expr_in_expr(heap, stmt.test, f))
 
            .or_else(|| seek_expr_in_stmt(heap, stmt.body, f))
 
        },
 
        Statement::Synchronous(stmt) => {
 
            seek_expr_in_stmt(heap, stmt.body, f)
 
        },
 
        Statement::Return(stmt) => {
 
            for expr_id in &stmt.expressions {
 
                if let Some(id) = seek_expr_in_expr(heap, *expr_id, f) {
 
                    return Some(id);
 
                }
 
            }
 
            None
 
        },
 
        Statement::New(stmt) => {
 
            seek_expr_in_expr(heap, stmt.expression.upcast(), f)
 
        },
 
        Statement::Expression(stmt) => {
 
            seek_expr_in_expr(heap, stmt.expression, f)
 
        },
 
        _ => None
 
    }
 
}
 

	
 
struct FakeRunContext{}
 
impl RunContext for FakeRunContext {
 
    fn performed_put(&mut self, _port: PortId) -> bool { unreachable!() }
 
    fn performed_get(&mut self, _port: PortId) -> Option<ValueGroup> { unreachable!() }
 
    fn fires(&mut self, _port: PortId) -> Option<Value> { unreachable!() }
 
    fn performed_fork(&mut self) -> Option<bool> { unreachable!() }
 
    fn created_channel(&mut self) -> Option<(Value, Value)> { unreachable!() }
 
    fn performed_select_start(&mut self) -> bool { unreachable!() }
 
    fn performed_select_register_port(&mut self) -> bool { unreachable!() }
 
    fn performed_select_wait(&mut self) -> Option<u32> { unreachable!() }
 
}
 
\ No newline at end of file
src/random.rs
Show inline comments
 
new file 100644
 
/**
 
 * random.rs
 
 *
 
 * Simple wrapper over a random number generator. Put here so that we can have
 
 * a feature flag for particular forms of randomness. For now we'll use pseudo-
 
 * randomness since that will help debugging.
 
 */
 

	
 
use rand::{RngCore, SeedableRng};
 
use rand_pcg;
 

	
 
pub(crate) struct Random {
 
    rng: rand_pcg::Lcg64Xsh32,
 
}
 

	
 
impl Random {
 
    pub(crate) fn new() -> Self {
 
        use std::time::SystemTime;
 

	
 
        let now = SystemTime::now();
 
        let elapsed = match now.duration_since(SystemTime::UNIX_EPOCH) {
 
            Ok(elapsed) => elapsed,
 
            Err(err) => err.duration(),
 
        };
 

	
 
        let elapsed = elapsed.as_nanos();
 
        let seed = elapsed.to_le_bytes();
 

	
 
        return Self::new_seeded(seed);
 
    }
 

	
 
    pub(crate) fn new_seeded(seed: [u8; 16]) -> Self {
 
        return Self{ rng: rand_pcg::Pcg32::from_seed(seed) }
 
    }
 

	
 
    pub(crate) fn get_u64(&mut self) -> u64 {
 
        return self.rng.next_u64();
 
    }
 
}
 
\ No newline at end of file
src/runtime/connector.rs
Show inline comments
 
// connector.rs
 
//
 
// Represents a component. A component (and the scheduler that is running it)
 
// has many properties that are not easy to subdivide into aspects that are
 
// conceptually handled by particular data structures. That is to say: the code
 
// that we run governs: running PDL code, keeping track of ports, instantiating
 
// new components and transports (i.e. interacting with the runtime), running
 
// a consensus algorithm, etc. But on the other hand, our data is rather
 
// simple: we have a speculative execution tree, a set of ports that we own,
 
// and a bit of code that we should run.
 
//
 
// So currently the code is organized as following:
 
// - The scheduler that is running the component is the authoritative source on
 
//     ports during *non-sync* mode. The consensus algorithm is the
 
//     authoritative source during *sync* mode. They retrieve each other's
 
//     state during the transitions. Hence port data exists duplicated between
 
//     these two datastructures.
 
// - The execution tree is where executed branches reside. But the execution
 
//     tree is only aware of the tree shape itself (and keeps track of some
 
//     queues of branches that are in a particular state), and tends to store
 
//     the PDL program state. The consensus algorithm is also somewhat aware
 
//     of the execution tree, but only in terms of what is needed to complete
 
//     a sync round (for now, that means the port mapping in each branch).
 
//     Hence once more we have properties conceptually associated with branches
 
//     in two places.
 
// - TODO: Write about handling messages, consensus wrapping data
 
// - TODO: Write about way information is exchanged between PDL/component and scheduler through ctx
 

	
 
use std::sync::atomic::AtomicBool;
 

	
 
use crate::ProtocolDescription;
 
use crate::protocol::eval::{EvalContinuation, EvalError, Prompt, Value, PortId, ValueGroup};
 
use crate::protocol::RunContext;
 

	
 
use super::branch::{BranchId, ExecTree, QueueKind, SpeculativeState, PreparedStatement};
 
use super::consensus::{Consensus, Consistency, RoundConclusion, find_ports_in_value_group};
 
use super::inbox::{DataMessage, Message, SyncCompMessage, SyncPortMessage, SyncControlMessage, PublicInbox};
 
use super::native::Connector;
 
use super::port::{PortKind, PortIdLocal};
 
use super::scheduler::{ComponentCtx, SchedulerCtx, MessageTicket};
 

	
 
pub(crate) struct ConnectorPublic {
 
    pub inbox: PublicInbox,
 
    pub sleeping: AtomicBool,
 
}
 

	
 
impl ConnectorPublic {
 
    pub fn new(initialize_as_sleeping: bool) -> Self {
 
        ConnectorPublic{
 
            inbox: PublicInbox::new(),
 
            sleeping: AtomicBool::new(initialize_as_sleeping),
 
        }
 
    }
 
}
 

	
 
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
 
enum Mode {
 
    NonSync,    // running non-sync code
 
    Sync,       // running sync code (in potentially multiple branches)
 
    SyncError,  // encountered an unrecoverable error in sync mode
 
    Error,      // encountered an error in non-sync mode (or finished handling the sync mode error).
 
}
 

	
 
#[derive(Debug)]
 
pub(crate) enum ConnectorScheduling {
 
    Immediate,          // Run again, immediately
 
    Later,              // Schedule for running, at some later point in time
 
    NotNow,             // Do not reschedule for running
 
    Exit,               // Connector has exited
 
}
 

	
 
pub(crate) struct ConnectorPDL {
 
    mode: Mode,
 
    eval_error: Option<EvalError>,
 
    tree: ExecTree,
 
    consensus: Consensus,
 
    last_finished_handled: Option<BranchId>,
 
}
 

	
 
struct ConnectorRunContext<'a> {
 
    branch_id: BranchId,
 
    consensus: &'a Consensus,
 
    prepared: PreparedStatement,
 
}
 

	
 
impl<'a> RunContext for ConnectorRunContext<'a>{
 
    fn performed_put(&mut self, _port: PortId) -> bool {
 
        return match self.prepared.take() {
 
            PreparedStatement::None => false,
 
            PreparedStatement::PerformedPut => true,
 
            taken => unreachable!("prepared statement is '{:?}' during 'performed_put()'", taken)
 
        };
 
    }
 

	
 
    fn performed_get(&mut self, _port: PortId) -> Option<ValueGroup> {
 
        return match self.prepared.take() {
 
            PreparedStatement::None => None,
 
            PreparedStatement::PerformedGet(value) => Some(value),
 
            taken => unreachable!("prepared statement is '{:?}' during 'performed_get()'", taken),
 
        };
 
    }
 

	
 
    fn fires(&mut self, _port: PortId) -> Option<Value> {
 
        todo!("Remove fires() now")
 
        // let port_id = PortIdLocal::new(port.id);
 
        // let annotation = self.consensus.get_annotation(self.branch_id, port_id);
 
        // return annotation.expected_firing.map(|v| Value::Bool(v));
 
    }
 

	
 
    fn created_channel(&mut self) -> Option<(Value, Value)> {
 
        return match self.prepared.take() {
 
            PreparedStatement::None => None,
 
            PreparedStatement::CreatedChannel(ports) => Some(ports),
 
            taken => unreachable!("prepared statement is '{:?}' during 'created_channel()'", taken),
 
        };
 
    }
 

	
 
    fn performed_fork(&mut self) -> Option<bool> {
 
        return match self.prepared.take() {
 
            PreparedStatement::None => None,
 
            PreparedStatement::ForkedExecution(path) => Some(path),
 
            taken => unreachable!("prepared statement is '{:?}' during 'performed_fork()'", taken),
 
        };
 
    }
 

	
 
    fn performed_select_start(&mut self) -> bool { unreachable!() }
 
    fn performed_select_register_port(&mut self) -> bool { unreachable!() }
 
    fn performed_select_wait(&mut self) -> Option<u32> { unreachable!() }
 
}
 

	
 
impl Connector for ConnectorPDL {
 
    fn run(&mut self, sched_ctx: SchedulerCtx, comp_ctx: &mut ComponentCtx) -> ConnectorScheduling {
 
        if let Some(scheduling) = self.handle_new_messages(comp_ctx) {
 
            return scheduling;
 
        }
 

	
 
        match self.mode {
 
            Mode::Sync => {
 
                // Run in sync mode
 
                let scheduling = self.run_in_sync_mode(sched_ctx, comp_ctx);
 

	
 
                // Handle any new finished branches
 
                let mut iter_id = self.last_finished_handled.or(self.tree.get_queue_first(QueueKind::FinishedSync));
 
                while let Some(branch_id) = iter_id {
 
                    iter_id = self.tree.get_queue_next(branch_id);
 
                    self.last_finished_handled = Some(branch_id);
 

	
 
                    if let Some(round_conclusion) = self.consensus.handle_new_finished_sync_branch(branch_id, comp_ctx) {
 
                        // Actually found a solution
 
                        return self.enter_non_sync_mode(round_conclusion, comp_ctx);
 
                    }
 

	
 
                    self.last_finished_handled = Some(branch_id);
 
                }
 

	
 
                return scheduling;
 
            },
 
            Mode::NonSync => {
 
                let scheduling = self.run_in_deterministic_mode(sched_ctx, comp_ctx);
 
                return scheduling;
 
            },
 
            Mode::SyncError => {
 
                let scheduling = self.run_in_sync_mode(sched_ctx, comp_ctx);
 
                return scheduling;
 
            },
 
            Mode::Error => {
 
                // This shouldn't really be called. Because when we reach exit
 
                // mode the scheduler should not run the component anymore
 
                unreachable!("called component run() during error-mode");
 
            },
 
        }
 
    }
 
}
 

	
 
impl ConnectorPDL {
 
    pub fn new(initial: Prompt) -> Self {
 
        Self{
 
            mode: Mode::NonSync,
 
            eval_error: None,
 
            tree: ExecTree::new(initial),
 
            consensus: Consensus::new(),
 
            last_finished_handled: None,
 
        }
 
    }
 

	
 
    // --- Handling messages
 

	
 
    pub fn handle_new_messages(&mut self, ctx: &mut ComponentCtx) -> Option<ConnectorScheduling> {
 
        while let Some(ticket) = ctx.get_next_message_ticket() {
 
            let message = ctx.read_message_using_ticket(ticket);
 
            let immediate_result = if let Message::Data(_) = message {
 
                self.handle_new_data_message(ticket, ctx);
 
                None
 
            } else {
 
                match ctx.take_message_using_ticket(ticket) {
 
                    Message::Data(_) => unreachable!(),
 
                    Message::SyncComp(message) => {
 
                        self.handle_new_sync_comp_message(message, ctx)
 
                    },
 
                    Message::SyncPort(message) => {
 
                        self.handle_new_sync_port_message(message, ctx);
 
                        None
 
                    },
 
                    Message::SyncControl(message) => {
 
                        self.handle_new_sync_control_message(message, ctx)
 
                    },
 
                    Message::Control(_) => unreachable!("control message in component"),
 
                }
 
            };
 

	
 
            if let Some(result) = immediate_result {
 
                return Some(result);
 
            }
 
        }
 

	
 
        return None;
 
    }
 

	
 
    pub fn handle_new_data_message(&mut self, ticket: MessageTicket, ctx: &mut ComponentCtx) {
 
        // Go through all branches that are awaiting new messages and see if
 
        // there is one that can receive this message.
 
        if !self.consensus.handle_new_data_message(ticket, ctx) {
 
            // Message should not be handled now
 
            return;
 
        }
 

	
 
        let message = ctx.read_message_using_ticket(ticket).as_data();
 
        let mut iter_id = self.tree.get_queue_first(QueueKind::AwaitingMessage);
 
        while let Some(branch_id) = iter_id {
 
            iter_id = self.tree.get_queue_next(branch_id);
 

	
 
            let branch = &self.tree[branch_id];
 
            if branch.awaiting_port != message.data_header.target_port { continue; }
 
            if !self.consensus.branch_can_receive(branch_id, &message) { continue; }
 

	
 
            // This branch can receive, so fork and given it the message
 
            let receiving_branch_id = self.tree.fork_branch(branch_id);
 
            self.consensus.notify_of_new_branch(branch_id, receiving_branch_id);
 
            let receiving_branch = &mut self.tree[receiving_branch_id];
 

	
 
            debug_assert!(receiving_branch.awaiting_port == message.data_header.target_port);
 
            receiving_branch.awaiting_port = PortIdLocal::new_invalid();
 
            receiving_branch.prepared = PreparedStatement::PerformedGet(message.content.clone());
 
            self.consensus.notify_of_received_message(receiving_branch_id, &message, ctx);
 

	
 
            // And prepare the branch for running
 
            self.tree.push_into_queue(QueueKind::Runnable, receiving_branch_id);
 
        }
 
    }
 

	
 
    pub fn handle_new_sync_comp_message(&mut self, message: SyncCompMessage, ctx: &mut ComponentCtx) -> Option<ConnectorScheduling> {
 
        if let Some(round_conclusion) = self.consensus.handle_new_sync_comp_message(message, ctx) {
 
            return Some(self.enter_non_sync_mode(round_conclusion, ctx));
 
        }
 

	
 
        return None;
 
    }
 

	
 
    pub fn handle_new_sync_port_message(&mut self, message: SyncPortMessage, ctx: &mut ComponentCtx) {
 
        self.consensus.handle_new_sync_port_message(message, ctx);
 
    }
 

	
 
    pub fn handle_new_sync_control_message(&mut self, message: SyncControlMessage, ctx: &mut ComponentCtx) -> Option<ConnectorScheduling> {
 
        if let Some(round_conclusion) = self.consensus.handle_new_sync_control_message(message, ctx) {
 
            return Some(self.enter_non_sync_mode(round_conclusion, ctx));
 
        }
 

	
 
        return None;
 
    }
 

	
 
    // --- Running code
 

	
 
    pub fn run_in_sync_mode(&mut self, sched_ctx: SchedulerCtx, comp_ctx: &mut ComponentCtx) -> ConnectorScheduling {
 
        // Check if we have any branch that needs running
 
        debug_assert!(self.tree.is_in_sync() && self.consensus.is_in_sync());
 
        let branch_id = self.tree.pop_from_queue(QueueKind::Runnable);
 
        if branch_id.is_none() {
 
            return ConnectorScheduling::NotNow;
 
        }
 

	
 
        // Retrieve the branch and run it
 
        let branch_id = branch_id.unwrap();
 
        let branch = &mut self.tree[branch_id];
 

	
 
        let mut run_context = ConnectorRunContext{
 
            branch_id,
 
            consensus: &self.consensus,
 
            prepared: branch.prepared.take(),
 
        };
 

	
 
        let run_result = Self::run_prompt(&mut branch.code_state, &sched_ctx.runtime.protocol_description, &mut run_context);
 
        if let Err(eval_error) = run_result {
 
            self.eval_error = Some(eval_error);
 
            self.mode = Mode::SyncError;
 
            if let Some(conclusion) = self.consensus.notify_of_fatal_branch(branch_id, comp_ctx) {
 
                // We can exit immediately
 
                return self.enter_non_sync_mode(conclusion, comp_ctx);
 
            } else {
 
                // Current branch failed. But we may have other things that are
 
                // running.
 
                return ConnectorScheduling::Immediate;
 
            }
 
        }
 
        let run_result = run_result.unwrap();
 

	
 
        // Handle the returned result. Note that this match statement contains
 
        // explicit returns in case the run result requires that the component's
 
        // code is ran again immediately
 
        match run_result {
 
            EvalContinuation::BranchInconsistent => {
 
                // Branch became inconsistent
 
                branch.sync_state = SpeculativeState::Inconsistent;
 
            },
 
            EvalContinuation::BlockFires(port_id) => {
 
                // Branch called `fires()` on a port that has not been used yet.
 
                let port_id = PortIdLocal::new(port_id.id);
 

	
 
                // Create two forks, one that assumes the port will fire, and
 
                // one that assumes the port remains silent
src/runtime2/component/component_context.rs
Show inline comments
 
use crate::runtime2::scheduler::*;
 
use crate::runtime2::runtime::*;
 
use crate::runtime2::communication::*;
 

	
 
#[derive(Debug)]
 
pub struct Port {
 
    pub self_id: PortId,
 
    pub peer_comp_id: CompId, // eventually consistent
 
    pub peer_port_id: PortId, // eventually consistent
 
    pub kind: PortKind,
 
    pub state: PortState,
 
    #[cfg(debug_assertions)] pub(crate) associated_with_peer: bool,
 
}
 

	
 
pub struct Peer {
 
    pub id: CompId,
 
    pub num_associated_ports: u32,
 
    pub(crate) handle: CompHandle,
 
}
 

	
 
/// Port and peer management structure. Will keep a local reference counter to
 
/// the ports associate with peers, additionally manages the atomic reference
 
/// counter associated with the peers' component handles.
 
pub struct CompCtx {
 
    pub id: CompId,
 
    ports: Vec<Port>,
 
    peers: Vec<Peer>,
 
    port_id_counter: u32,
 
}
 

	
 
#[derive(Copy, Clone)]
 
#[derive(Copy, Clone, PartialEq, Eq)]
 
pub struct LocalPortHandle(PortId);
 

	
 
#[derive(Copy, Clone)]
 
pub struct LocalPeerHandle(CompId);
 

	
 
impl CompCtx {
 
    /// Creates a new component context based on a reserved entry in the
 
    /// component store. This reservation is used such that we already know our
 
    /// assigned ID.
 
    pub(crate) fn new(reservation: &CompReserved) -> Self {
 
        return Self{
 
            id: reservation.id(),
 
            ports: Vec::new(),
 
            peers: Vec::new(),
 
            port_id_counter: 0,
 
        }
 
    }
 

	
 
    /// Creates a new channel that is fully owned by the component associated
 
    /// with this context.
 
    pub(crate) fn create_channel(&mut self) -> Channel {
 
        let putter_id = PortId(self.take_port_id());
 
        let getter_id = PortId(self.take_port_id());
 
        self.ports.push(Port{
 
            self_id: putter_id,
 
            peer_port_id: getter_id,
 
            kind: PortKind::Putter,
 
            state: PortState::Open,
 
            peer_comp_id: self.id,
 
            associated_with_peer: false,
 
        });
 
        self.ports.push(Port{
 
            self_id: getter_id,
 
            peer_port_id: putter_id,
 
            kind: PortKind::Getter,
 
            state: PortState::Open,
 
            peer_comp_id: self.id,
 
            associated_with_peer: false,
 
        });
 

	
 
        return Channel{ putter_id, getter_id };
 
    }
 

	
 
    /// Adds a new port. Make sure to call `add_peer` afterwards.
 
    pub(crate) fn add_port(&mut self, peer_comp_id: CompId, peer_port_id: PortId, kind: PortKind, state: PortState) -> LocalPortHandle {
 
        let self_id = PortId(self.take_port_id());
 
        self.ports.push(Port{
 
            self_id, peer_comp_id, peer_port_id, kind, state,
 
            #[cfg(debug_assertions)] associated_with_peer: false,
 
        });
 
        return LocalPortHandle(self_id);
 
    }
 

	
 
    /// Removes a port. Make sure you called `remove_peer` first.
 
    pub(crate) fn remove_port(&mut self, port_handle: LocalPortHandle) -> Port {
 
        let port_index = self.must_get_port_index(port_handle);
 
        let port = self.ports.remove(port_index);
 
        debug_assert!(!port.associated_with_peer);
 
        return port;
 
    }
 

	
 
    /// Adds a new peer. This must be called for every port, no matter the
 
    /// component the channel is connected to. If a `CompHandle` is supplied,
 
    /// then it will be used to add the peer. Otherwise it will be retrieved
 
    /// from the runtime using its ID.
 
    pub(crate) fn add_peer(&mut self, port_handle: LocalPortHandle, sched_ctx: &SchedulerCtx, peer_comp_id: CompId, handle: Option<&CompHandle>) {
 
        let self_id = self.id;
 
        let port = self.get_port_mut(port_handle);
 
        debug_assert_eq!(port.peer_comp_id, peer_comp_id);
 
        debug_assert!(!port.associated_with_peer);
 
        if !Self::requires_peer_reference(port, self_id, false) {
 
            return;
 
        }
 

	
 
        dbg_code!(port.associated_with_peer = true);
 
        match self.get_peer_index_by_id(peer_comp_id) {
 
            Some(peer_index) => {
 
                let peer = &mut self.peers[peer_index];
 
                peer.num_associated_ports += 1;
 
            },
 
            None => {
 
                let handle = match handle {
 
                    Some(handle) => handle.clone(),
 
                    None => sched_ctx.runtime.get_component_public(peer_comp_id)
 
                };
 
                self.peers.push(Peer{
 
                    id: peer_comp_id,
 
                    num_associated_ports: 1,
 
                    handle,
 
                });
 
            }
 
        }
 
    }
 

	
 
    /// Removes a peer associated with a port.
 
    pub(crate) fn remove_peer(&mut self, sched_ctx: &SchedulerCtx, port_handle: LocalPortHandle, peer_id: CompId, also_remove_if_closed: bool) {
 
        let self_id = self.id;
 
        let port = self.get_port_mut(port_handle);
 
        debug_assert_eq!(port.peer_comp_id, peer_id);
 
        if !Self::requires_peer_reference(port, self_id, also_remove_if_closed) {
 
            return;
 
        }
 

	
 
        debug_assert!(port.associated_with_peer);
 
        dbg_code!(port.associated_with_peer = false);
 
        let peer_index = self.get_peer_index_by_id(peer_id).unwrap();
 
        let peer = &mut self.peers[peer_index];
 
        peer.num_associated_ports -= 1;
 
        if peer.num_associated_ports == 0 {
 
            let mut peer = self.peers.remove(peer_index);
 
            if let Some(key) = peer.handle.decrement_users() {
 
                debug_assert_ne!(key.downgrade(), self.id); // should be upheld by the code that shuts down a component
 
                sched_ctx.runtime.destroy_component(key);
 
            }
 
        }
 
    }
 

	
 
    pub(crate) fn set_port_state(&mut self, port_handle: LocalPortHandle, new_state: PortState) {
 
        let port_info = self.get_port_mut(port_handle);
 
        debug_assert_ne!(port_info.state, PortState::Closed); // because then we do not expect to change the state
 
        port_info.state = new_state;
 
    }
 

	
 
    pub(crate) fn get_port_handle(&self, port_id: PortId) -> LocalPortHandle {
 
        return LocalPortHandle(port_id);
 
    }
 

	
 
    // should perhaps be revised, used in main inbox
 
    pub(crate) fn get_port_index(&self, port_handle: LocalPortHandle) -> usize {
 
        return self.must_get_port_index(port_handle);
 
    }
 

	
 
    pub(crate) fn get_peer_handle(&self, peer_id: CompId) -> LocalPeerHandle {
 
        return LocalPeerHandle(peer_id);
 
    }
 

	
 
    pub(crate) fn get_port(&self, port_handle: LocalPortHandle) -> &Port {
 
        let index = self.must_get_port_index(port_handle);
 
        return &self.ports[index];
 
    }
 

	
 
    pub(crate) fn get_port_mut(&mut self, port_handle: LocalPortHandle) -> &mut Port {
 
        let index = self.must_get_port_index(port_handle);
 
        return &mut self.ports[index];
 
    }
 

	
 
    pub(crate) fn get_port_by_index_mut(&mut self, index: usize) -> &mut Port {
 
        return &mut self.ports[index];
 
    }
 

	
 
    pub(crate) fn get_peer(&self, peer_handle: LocalPeerHandle) -> &Peer {
 
        let index = self.must_get_peer_index(peer_handle);
 
        return &self.peers[index];
 
    }
 

	
 
    pub(crate) fn get_peer_mut(&mut self, peer_handle: LocalPeerHandle) -> &mut Peer {
 
        let index = self.must_get_peer_index(peer_handle);
 
        return &mut self.peers[index];
 
    }
 

	
 
    #[inline]
 
    pub(crate) fn iter_ports(&self) -> impl Iterator<Item=&Port> {
 
        return self.ports.iter();
 
    }
 

	
 
    #[inline]
 
    pub(crate) fn iter_ports_mut(&mut self) -> impl Iterator<Item=&mut Port> {
 
        return self.ports.iter_mut();
 
    }
 

	
 
    #[inline]
 
    pub(crate) fn iter_peers(&self) -> impl Iterator<Item=&Peer> {
 
        return self.peers.iter();
 
    }
 

	
 
    #[inline]
 
    pub(crate) fn num_ports(&self) -> usize {
 
        return self.ports.len();
 
    }
 

	
 
    // -------------------------------------------------------------------------
 
    // Local utilities
 
    // -------------------------------------------------------------------------
 

	
 
    #[inline]
 
    fn requires_peer_reference(port: &Port, self_id: CompId, required_if_closed: bool) -> bool {
 
        return (port.state != PortState::Closed || required_if_closed) && port.peer_comp_id != self_id;
 
    }
 

	
 
    fn must_get_port_index(&self, handle: LocalPortHandle) -> usize {
 
        for (index, port) in self.ports.iter().enumerate() {
 
            if port.self_id == handle.0 {
src/runtime2/component/component_pdl.rs
Show inline comments
 
use crate::random::Random;
 
use crate::protocol::*;
 
use crate::protocol::ast::ProcedureDefinitionId;
 
use crate::protocol::eval::{
 
    PortId as EvalPortId, Prompt,
 
    ValueGroup, Value,
 
    EvalContinuation, EvalResult, EvalError
 
};
 

	
 
use crate::runtime2::scheduler::SchedulerCtx;
 
use crate::runtime2::communication::*;
 

	
 
use super::component_context::*;
 
use super::control_layer::*;
 
use super::consensus::Consensus;
 

	
 
pub enum CompScheduling {
 
    Immediate,
 
    Requeue,
 
    Sleep,
 
    Exit,
 
}
 

	
 
pub enum ExecStmt {
 
    CreatedChannel((Value, Value)),
 
    PerformedPut,
 
    PerformedGet(ValueGroup),
 
    PerformedSelectStart,
 
    PerformedSelectRegister,
 
    PerformedSelectWait(u32),
 
    None,
 
}
 

	
 
impl ExecStmt {
 
    fn take(&mut self) -> ExecStmt {
 
        let mut value = ExecStmt::None;
 
        std::mem::swap(self, &mut value);
 
        return value;
 
    }
 

	
 
    fn is_none(&self) -> bool {
 
        match self {
 
            ExecStmt::None => return true,
 
            _ => return false,
 
        }
 
    }
 
}
 

	
 
pub struct ExecCtx {
 
    stmt: ExecStmt,
 
}
 

	
 
impl RunContext for ExecCtx {
 
    fn performed_put(&mut self, _port: EvalPortId) -> bool {
 
        match self.stmt.take() {
 
            ExecStmt::None => return false,
 
            ExecStmt::PerformedPut => return true,
 
            _ => unreachable!(),
 
        }
 
    }
 

	
 
    fn performed_get(&mut self, _port: EvalPortId) -> Option<ValueGroup> {
 
        match self.stmt.take() {
 
            ExecStmt::None => return None,
 
            ExecStmt::PerformedGet(value) => return Some(value),
 
            _ => unreachable!(),
 
        }
 
    }
 

	
 
    fn fires(&mut self, _port: EvalPortId) -> Option<Value> {
 
        todo!("remove fires")
 
    }
 

	
 
    fn performed_fork(&mut self) -> Option<bool> {
 
        todo!("remove fork")
 
    }
 

	
 
    fn created_channel(&mut self) -> Option<(Value, Value)> {
 
        match self.stmt.take() {
 
            ExecStmt::None => return None,
 
            ExecStmt::CreatedChannel(ports) => return Some(ports),
 
            _ => unreachable!(),
 
        }
 
    }
 

	
 
    fn performed_select_start(&mut self) -> bool {
 
        match self.stmt.take() {
 
            ExecStmt::None => return false,
 
            ExecStmt::PerformedSelectStart => return true,
 
            _ => unreachable!(),
 
        }
 
    }
 

	
 
    fn performed_select_register_port(&mut self) -> bool {
 
        match self.stmt.take() {
 
            ExecStmt::None => return false,
 
            ExecStmt::PerformedSelectRegister => return true,
 
            _ => unreachable!(),
 
        }
 
    }
 

	
 
    fn performed_select_wait(&mut self) -> Option<u32> {
 
        match self.stmt.take() {
 
            ExecStmt::None => return None,
 
            ExecStmt::PerformedSelectWait(selected_case) => Some(selected_case),
 
            _ => unreachable!(),
 
            _v => unreachable!(),
 
        }
 
    }
 
}
 

	
 
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
 
pub(crate) enum Mode {
 
    NonSync, // not in sync mode
 
    Sync, // in sync mode, can interact with other components
 
    SyncEnd, // awaiting a solution, i.e. encountered the end of the sync block
 
    BlockedGet,
 
    BlockedPut,
 
    BlockedGet, // blocked because we need to receive a message on a particular port
 
    BlockedPut, // component is blocked because the port is blocked
 
    BlockedSelect, // waiting on message to complete the select statement
 
    StartExit, // temporary state: if encountered then we start the shutdown process
 
    BusyExit, // temporary state: waiting for Acks for all the closed ports
 
    Exit, // exiting: shutdown process started, now waiting until the reference count drops to 0
 
}
 

	
 
struct SelectCase {
 
    involved_ports: Vec<LocalPortHandle>,
 
}
 

	
 
// TODO: @Optimize, flatten cases into single array, have index-pointers to next case
 
struct SelectState {
 
    cases: Vec<SelectCase>,
 
    next_case: u32,
 
    num_cases: u32,
 
    random: Random,
 
    candidates_workspace: Vec<usize>,
 
}
 

	
 
enum SelectDecision {
 
    None,
 
    Case(u32), // contains case index, should be passed along to PDL code
 
}
 

	
 
type InboxMain = Vec<Option<DataMessage>>;
 

	
 
impl SelectState {
 
    fn new() -> Self {
 
        return Self{
 
            cases: Vec::new(),
 
            next_case: 0,
 
            num_cases: 0,
 
            random: Random::new(),
 
            candidates_workspace: Vec::new(),
 
        }
 
    }
 

	
 
    fn handle_select_start(&mut self, num_cases: u32) {
 
        self.cases.clear();
 
        self.next_case = 0;
 
        self.num_cases = num_cases;
 
    }
 

	
 
    /// Register a port as belonging to a particular case. As for correctness of
 
    /// PDL code one cannot register the same port twice, this function might
 
    /// return an error
 
    fn register_select_case_port(&mut self, comp_ctx: &CompCtx, case_index: u32, _port_index: u32, port_id: PortId) -> Result<(), PortId> {
 
        // Retrieve case and port handle
 
        self.ensure_at_case(case_index);
 
        let cur_case = &mut self.cases[case_index as usize];
 
        let port_handle = comp_ctx.get_port_handle(port_id);
 
        debug_assert_eq!(cur_case.involved_ports.len(), _port_index as usize);
 

	
 
        // Make sure port wasn't added before, we disallow having the same port
 
        // in the same select guard twice.
 
        if cur_case.involved_ports.contains(&port_handle) {
 
            return Err(port_id);
 
        }
 

	
 
        cur_case.involved_ports.push(port_handle);
 
        return Ok(());
 
    }
 

	
 
    /// Notification that all ports have been registered and we should now wait
 
    /// until the appropriate messages have come in.
 
    fn handle_select_waiting_point(&mut self, inbox: &InboxMain, comp_ctx: &CompCtx) -> SelectDecision {
 
        if self.num_cases != self.next_case {
 
            // This happens when there are >=1 select cases written at the end
 
            // of the select block.
 
            self.ensure_at_case(self.num_cases - 1);
 
        }
 

	
 
        return self.has_decision(inbox, comp_ctx);
 
    }
 

	
 
    fn handle_updated_inbox(&mut self, inbox: &InboxMain, comp_ctx: &CompCtx) -> SelectDecision {
 
        return self.has_decision(inbox, comp_ctx);
 
    }
 

	
 
    /// Internal helper, pushes empty cases inbetween last case and provided new
 
    /// case index.
 
    fn ensure_at_case(&mut self, new_case_index: u32) {
 
        // Push an empty case for all intermediate cases that were not
 
        // registered with a port.
 
        debug_assert!(new_case_index >= self.next_case && new_case_index < self.num_cases);
 
        for _ in self.next_case..new_case_index + 1 {
 
            self.cases.push(SelectCase{ involved_ports: Vec::new() });
 
        }
 
        self.next_case = new_case_index + 1;
 
    }
 

	
 
    /// Checks if a decision can be reached
 
    fn has_decision(&mut self, inbox: &InboxMain, comp_ctx: &CompCtx) -> SelectDecision {
 
        self.candidates_workspace.clear();
 
        if self.cases.is_empty() {
 
            // If there are no cases then we can immediately reach a "bogus
 
            // decision".
 
            return SelectDecision::Case(0);
 
        }
 

	
 
        // Need to check for valid case
 
        'case_loop: for (case_index, case) in self.cases.iter().enumerate() {
 
            for port_handle in case.involved_ports.iter().copied() {
 
                let port_index = comp_ctx.get_port_index(port_handle);
 
                if inbox[port_index].is_none() {
 
                    // Condition not satisfied
 
                    continue 'case_loop;
 
                }
 
            }
 

	
 
            // If here then the case guard is satisfied
 
            self.candidates_workspace.push(case_index);
 
        }
 

	
 
        if self.candidates_workspace.is_empty() {
 
            return SelectDecision::None;
 
        } else {
 
            let candidate_index = self.random.get_u64() as usize % self.candidates_workspace.len();
 
            return SelectDecision::Case(self.candidates_workspace[candidate_index] as u32);
 
        }
 
    }
 
}
 

	
 
pub(crate) struct CompPDL {
 
    pub mode: Mode,
 
    pub mode_port: PortId, // when blocked on a port
 
    pub mode_value: ValueGroup, // when blocked on a put
 
    select: SelectState,
 
    pub prompt: Prompt,
 
    pub control: ControlLayer,
 
    pub consensus: Consensus,
 
    pub sync_counter: u32,
 
    pub exec_ctx: ExecCtx,
 
    // TODO: Temporary field, simulates future plans of having one storage place
 
    //  reserved per port.
 
    // Should be same length as the number of ports. Corresponding indices imply
 
    // message is intended for that port.
 
    pub inbox_main: Vec<Option<DataMessage>>,
 
    pub inbox_main: InboxMain,
 
    pub inbox_backup: Vec<DataMessage>,
 
}
 

	
 
impl CompPDL {
 
    pub(crate) fn new(initial_state: Prompt, num_ports: usize) -> Self {
 
        let mut inbox_main = Vec::new();
 
        inbox_main.reserve(num_ports);
 
        for _ in 0..num_ports {
 
            inbox_main.push(None);
 
        }
 

	
 
        return Self{
 
            mode: Mode::NonSync,
 
            mode_port: PortId::new_invalid(),
 
            mode_value: ValueGroup::default(),
 
            select: SelectState::new(),
 
            prompt: initial_state,
 
            control: ControlLayer::default(),
 
            consensus: Consensus::new(),
 
            sync_counter: 0,
 
            exec_ctx: ExecCtx{
 
                stmt: ExecStmt::None,
 
            },
 
            inbox_main,
 
            inbox_backup: Vec::new(),
 
        }
 
    }
 

	
 
    pub(crate) fn handle_message(&mut self, sched_ctx: &mut SchedulerCtx, comp_ctx: &mut CompCtx, mut message: Message) {
 
        sched_ctx.log(&format!("handling message: {:#?}", message));
 
        if let Some(new_target) = self.control.should_reroute(&mut message) {
 
            let mut target = sched_ctx.runtime.get_component_public(new_target);
 
            target.send_message(sched_ctx, message, false); // not waking up: we schedule once we've received all PortPeerChanged Acks
 
            let _should_remove = target.decrement_users();
 
            debug_assert!(_should_remove.is_none());
 
            return;
 
        }
 

	
 
        match message {
 
            Message::Data(message) => {
 
                self.handle_incoming_data_message(sched_ctx, comp_ctx, message);
 
            },
 
            Message::Control(message) => {
 
                self.handle_incoming_control_message(sched_ctx, comp_ctx, message);
 
            },
 
            Message::Sync(message) => {
 
                self.handle_incoming_sync_message(sched_ctx, comp_ctx, message);
 
            }
 
        }
 
    }
 

	
 
    // -------------------------------------------------------------------------
 
    // Running component and handling changes in global component state
 
    // -------------------------------------------------------------------------
 

	
 
    pub(crate) fn run(&mut self, sched_ctx: &mut SchedulerCtx, comp_ctx: &mut CompCtx) -> Result<CompScheduling, EvalError> {
 
        use EvalContinuation as EC;
 

	
 
        sched_ctx.log(&format!("Running component (mode: {:?})", self.mode));
 

	
 
        // Depending on the mode don't do anything at all, take some special
 
        // actions, or fall through and run the PDL code.
 
        match self.mode {
 
            Mode::NonSync | Mode::Sync => {},
 
            Mode::NonSync | Mode::Sync | Mode::BlockedSelect => {
 
                // continue and run PDL code
 
            },
 
            Mode::SyncEnd | Mode::BlockedGet | Mode::BlockedPut => {
 
                return Ok(CompScheduling::Sleep);
 
            }
 
            Mode::StartExit => {
 
                self.handle_component_exit(sched_ctx, comp_ctx);
 
                return Ok(CompScheduling::Immediate);
 
            },
 
            Mode::BusyExit => {
 
                if self.control.has_acks_remaining() {
 
                    return Ok(CompScheduling::Sleep);
 
                } else {
 
                    self.mode = Mode::Exit;
 
                    return Ok(CompScheduling::Exit);
 
                }
 
            },
 
            Mode::Exit => {
 
                return Ok(CompScheduling::Exit);
 
            }
 
        }
 

	
 
        let run_result = self.execute_prompt(&sched_ctx)?;
 

	
 
        match run_result {
 
            EC::Stepping => unreachable!(), // execute_prompt runs until this is no longer returned
 
            EC::BranchInconsistent | EC::NewFork | EC::BlockFires(_) => todo!("remove these"),
 
            // Results that can be returned in sync mode
 
            EC::SyncBlockEnd => {
 
                debug_assert_eq!(self.mode, Mode::Sync);
 
                self.handle_sync_end(sched_ctx, comp_ctx);
 
                return Ok(CompScheduling::Immediate);
 
            },
 
            EC::BlockGet(port_id) => {
 
                debug_assert_eq!(self.mode, Mode::Sync);
 
                debug_assert!(self.exec_ctx.stmt.is_none());
 

	
 
                let port_id = port_id_from_eval(port_id);
 
                let port_handle = comp_ctx.get_port_handle(port_id);
 
                let port_index = comp_ctx.get_port_index(port_handle);
 
                if let Some(message) = &self.inbox_main[port_index] {
 
                    // Check if we can actually receive the message
 
                    if self.consensus.try_receive_data_message(sched_ctx, comp_ctx, message) {
 
                        // Message was received. Make sure any blocked peers and
 
                        // pending messages are handled.
 
                        let message = self.inbox_main[port_index].take().unwrap();
 
                        self.handle_received_data_message(sched_ctx, comp_ctx, port_handle);
 

	
 
                        self.exec_ctx.stmt = ExecStmt::PerformedGet(message.content);
 
                        return Ok(CompScheduling::Immediate);
 
                    } else {
 
                        todo!("handle sync failure due to message deadlock");
 
                        return Ok(CompScheduling::Sleep);
 
                    }
 
                } else {
 
                    // We need to wait
 
                    self.mode = Mode::BlockedGet;
 
                    self.mode_port = port_id;
 
                    return Ok(CompScheduling::Sleep);
 
                }
 
            },
 
            EC::Put(port_id, value) => {
 
                debug_assert_eq!(self.mode, Mode::Sync);
 
                sched_ctx.log(&format!("Putting value {:?}", value));
 
                let port_id = port_id_from_eval(port_id);
 
                let port_handle = comp_ctx.get_port_handle(port_id);
 
                let port_info = comp_ctx.get_port(port_handle);
 
                if port_info.state.is_blocked() {
 
                    self.mode = Mode::BlockedPut;
 
                    self.mode_port = port_id;
 
                    self.mode_value = value;
 
                    self.exec_ctx.stmt = ExecStmt::PerformedPut; // prepare for when we become unblocked
 
                    return Ok(CompScheduling::Sleep);
 
                } else {
 
                    self.send_data_message_and_wake_up(sched_ctx, comp_ctx, port_handle, value);
 
                    self.exec_ctx.stmt = ExecStmt::PerformedPut;
 
                    return Ok(CompScheduling::Immediate);
 
                }
 
            },
 
            EC::SelectStart(num_cases, num_ports) => {
 
            EC::SelectStart(num_cases, _num_ports) => {
 
                debug_assert_eq!(self.mode, Mode::Sync);
 
                todo!("finish handling select start")
 
                self.select.handle_select_start(num_cases);
 
                return Ok(CompScheduling::Requeue);
 
            },
 
            EC::SelectRegisterPort(case_index, port_index, port_id) => {
 
                debug_assert_eq!(self.mode, Mode::Sync);
 
                todo!("finish handling register port")
 
                let port_id = port_id_from_eval(port_id);
 
                if let Err(_err) = self.select.register_select_case_port(comp_ctx, case_index, port_index, port_id) {
 
                    todo!("handle registering a port multiple times");
 
                }
 
                return Ok(CompScheduling::Immediate);
 
            },
 
            EC::SelectWait => {
 
                debug_assert_eq!(self.mode, Mode::Sync);
 
                self.handle_select_wait(sched_ctx, comp_ctx);
 
                todo!("finish handling select wait")
 
                let select_decision = self.select.handle_select_waiting_point(&self.inbox_main, comp_ctx);
 
                if let SelectDecision::Case(case_index) = select_decision {
 
                    // Reached a conclusion, so we can continue immediately
 
                    self.exec_ctx.stmt = ExecStmt::PerformedSelectWait(case_index);
 
                    self.mode = Mode::Sync;
 
                    return Ok(CompScheduling::Immediate);
 
                } else {
 
                    // No decision yet
 
                    self.mode = Mode::BlockedSelect;
 
                    return Ok(CompScheduling::Sleep);
 
                }
 
            },
 
            // Results that can be returned outside of sync mode
 
            EC::ComponentTerminated => {
 
                self.mode = Mode::StartExit; // next call we'll take care of the exit
 
                return Ok(CompScheduling::Immediate);
 
            },
 
            EC::SyncBlockStart => {
 
                debug_assert_eq!(self.mode, Mode::NonSync);
 
                self.handle_sync_start(sched_ctx, comp_ctx);
 
                return Ok(CompScheduling::Immediate);
 
            },
 
            EC::NewComponent(definition_id, type_id, arguments) => {
 
                debug_assert_eq!(self.mode, Mode::NonSync);
 
                self.create_component_and_transfer_ports(
 
                    sched_ctx, comp_ctx,
 
                    definition_id, type_id, arguments
 
                );
 
                return Ok(CompScheduling::Requeue);
 
            },
 
            EC::NewChannel => {
 
                debug_assert_eq!(self.mode, Mode::NonSync);
 
                debug_assert!(self.exec_ctx.stmt.is_none());
 
                let channel = comp_ctx.create_channel();
 
                self.exec_ctx.stmt = ExecStmt::CreatedChannel((
 
                    Value::Output(port_id_to_eval(channel.putter_id)),
 
                    Value::Input(port_id_to_eval(channel.getter_id))
 
                ));
 
                self.inbox_main.push(None);
 
                self.inbox_main.push(None);
 
                return Ok(CompScheduling::Immediate);
 
            }
 
        }
 
    }
 

	
 
    fn execute_prompt(&mut self, sched_ctx: &SchedulerCtx) -> EvalResult {
 
        let mut step_result = EvalContinuation::Stepping;
 
        while let EvalContinuation::Stepping = step_result {
 
            step_result = self.prompt.step(
 
                &sched_ctx.runtime.protocol.types, &sched_ctx.runtime.protocol.heap,
 
                &sched_ctx.runtime.protocol.modules, &mut self.exec_ctx,
 
            )?;
 
        }
 

	
 
        return Ok(step_result)
 
    }
 

	
 
    fn handle_sync_start(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx) {
 
        sched_ctx.log("Component starting sync mode");
 
        self.consensus.notify_sync_start(comp_ctx);
 
        debug_assert_eq!(self.mode, Mode::NonSync);
 
        self.mode = Mode::Sync;
 
    }
 

	
 
    /// Handles end of sync. The conclusion to the sync round might arise
 
    /// immediately (and be handled immediately), or might come later through
 
    /// messaging. In any case the component should be scheduled again
 
    /// immediately
 
    fn handle_sync_end(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx) {
 
        sched_ctx.log("Component ending sync mode (now waiting for solution)");
 
        let decision = self.consensus.notify_sync_end(sched_ctx, comp_ctx);
 
        self.mode = Mode::SyncEnd;
 
        self.handle_sync_decision(sched_ctx, comp_ctx, decision);
 
    }
 

	
 
    /// Handles decision from the consensus round. This will cause a change in
 
    /// the internal `Mode`, such that the next call to `run` can take the
 
    /// appropriate next steps.
 
    fn handle_sync_decision(&mut self, sched_ctx: &SchedulerCtx, _comp_ctx: &mut CompCtx, decision: SyncRoundDecision) {
 
        sched_ctx.log(&format!("Handling sync decision: {:?} (in mode {:?})", decision, self.mode));
 
        let is_success = match decision {
 
            SyncRoundDecision::None => {
 
                // No decision yet
 
                return;
 
            },
 
            SyncRoundDecision::Solution => true,
 
            SyncRoundDecision::Failure => false,
 
        };
 

	
 
        // If here then we've reached a decision
 
        debug_assert_eq!(self.mode, Mode::SyncEnd);
 
        if is_success {
 
            self.mode = Mode::NonSync;
 
            self.consensus.notify_sync_decision(decision);
 
        } else {
 
            self.mode = Mode::StartExit;
 
        }
 
    }
 

	
 
    /// Handles the moment where the PDL code has notified the runtime of all
 
    /// the ports it is waiting on.
 
    fn handle_select_wait(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx) {
 
        sched_ctx.log("Component waiting for select conclusion");
 

	
 
    }
 

	
 
    fn handle_component_exit(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx) {
 
        sched_ctx.log("Component exiting");
 
        debug_assert_eq!(self.mode, Mode::StartExit);
 
        self.mode = Mode::BusyExit;
 

	
 
        // Doing this by index, then retrieving the handle is a bit rediculous,
 
        // but Rust is being Rust with its borrowing rules.
 
        for port_index in 0..comp_ctx.num_ports() {
 
            let port = comp_ctx.get_port_by_index_mut(port_index);
 
            if port.state == PortState::Closed {
 
                // Already closed, or in the process of being closed
 
                continue;
 
            }
 

	
 
            // Mark as closed
 
            let port_id = port.self_id;
 
            port.state = PortState::Closed;
 

	
 
            // Notify peer of closing
 
            let port_handle = comp_ctx.get_port_handle(port_id);
 
            let (peer, message) = self.control.initiate_port_closing(port_handle, comp_ctx);
 
            let peer_info = comp_ctx.get_peer(peer);
 
            peer_info.handle.send_message(sched_ctx, Message::Control(message), true);
 
        }
 
    }
 

	
 
    // -------------------------------------------------------------------------
 
    // Handling messages
 
    // -------------------------------------------------------------------------
 

	
 
    fn send_data_message_and_wake_up(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &CompCtx, source_port_handle: LocalPortHandle, value: ValueGroup) {
 
        let port_info = comp_ctx.get_port(source_port_handle);
 
        let peer_handle = comp_ctx.get_peer_handle(port_info.peer_comp_id);
 
        let peer_info = comp_ctx.get_peer(peer_handle);
 
        let annotated_message = self.consensus.annotate_data_message(comp_ctx, port_info, value);
 
        peer_info.handle.send_message(sched_ctx, Message::Data(annotated_message), true);
 
    }
 

	
 
    /// Handles a message that came in through the public inbox. This function
 
    /// will handle putting it in the correct place, and potentially blocking
 
    /// the port in case too many messages are being received.
 
    fn handle_incoming_data_message(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, message: DataMessage) {
 
        // Check if we can insert it directly into the storage associated with
 
        // the port
 
        let target_port_id = message.data_header.target_port;
 
        let port_handle = comp_ctx.get_port_handle(target_port_id);
 
        let port_index = comp_ctx.get_port_index(port_handle);
 
        if self.inbox_main[port_index].is_none() {
 
            self.inbox_main[port_index] = Some(message);
 

	
 
            // After direct insertion, check if this component's execution is 
 
            // blocked on receiving a message on that port
 
            debug_assert!(!comp_ctx.get_port(port_handle).state.is_blocked()); // because we could insert directly
 
            if self.mode == Mode::BlockedGet && self.mode_port == target_port_id {
 
                // We were indeed blocked
 
                self.mode = Mode::Sync;
 
                self.mode_port = PortId::new_invalid();
 
            } else if self.mode == Mode::BlockedSelect {
 
                let select_decision = self.select.handle_updated_inbox(&self.inbox_main, comp_ctx);
 
                if let SelectDecision::Case(case_index) = select_decision {
 
                    self.exec_ctx.stmt = ExecStmt::PerformedSelectWait(case_index);
 
                    self.mode = Mode::Sync;
 
                }
 
            }
 
            
 
            return;
 
        }
 

	
 
        // The direct inbox is full, so the port will become (or was already) blocked
 
        let port_info = comp_ctx.get_port_mut(port_handle);
 
        debug_assert!(port_info.state == PortState::Open || port_info.state.is_blocked());
 

	
 
        if port_info.state == PortState::Open {
 
            comp_ctx.set_port_state(port_handle, PortState::BlockedDueToFullBuffers);
 
            let (peer_handle, message) =
 
                self.control.initiate_port_blocking(comp_ctx, port_handle);
 

	
 
            let peer = comp_ctx.get_peer(peer_handle);
 
            peer.handle.send_message(sched_ctx, Message::Control(message), true);
 
        }
 

	
 
        // But we still need to remember the message, so:
 
        self.inbox_backup.push(message);
 
    }
 

	
 
    /// Handles when a message has been handed off from the inbox to the PDL
 
    /// code. We check to see if there are more messages waiting and, if not,
 
    /// then we handle the case where the port might have been blocked
 
    /// previously.
 
    fn handle_received_data_message(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, port_handle: LocalPortHandle) {
 
        let port_index = comp_ctx.get_port_index(port_handle);
 
        debug_assert!(self.inbox_main[port_index].is_none()); // this function should be called after the message is taken out
 

	
 
        // Check for any more messages
 
        let port_info = comp_ctx.get_port(port_handle);
 
        for message_index in 0..self.inbox_backup.len() {
 
            let message = &self.inbox_backup[message_index];
 
            if message.data_header.target_port == port_info.self_id {
 
                // One more message for this port
 
                let message = self.inbox_backup.remove(message_index);
 
                debug_assert!(comp_ctx.get_port(port_handle).state.is_blocked()); // since we had >1 message on the port
 
                self.inbox_main[port_index] = Some(message);
 

	
 
                return;
 
            }
 
        }
 

	
 
        // Did not have any more messages. So if we were blocked, then we need
 
        // to send the "unblock" message.
 
        if port_info.state == PortState::BlockedDueToFullBuffers {
 
            comp_ctx.set_port_state(port_handle, PortState::Open);
 
            let (peer_handle, message) = self.control.cancel_port_blocking(comp_ctx, port_handle);
 
            let peer_info = comp_ctx.get_peer(peer_handle);
 
            peer_info.handle.send_message(sched_ctx, Message::Control(message), true);
 
        }
 
    }
 

	
 
    fn handle_incoming_control_message(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, message: ControlMessage) {
 
        // Little local utility to send an Ack
 
        fn send_control_ack_message(sched_ctx: &SchedulerCtx, comp_ctx: &CompCtx, causer_id: ControlId, peer_handle: LocalPeerHandle) {
 
            let peer_info = comp_ctx.get_peer(peer_handle);
 
            peer_info.handle.send_message(sched_ctx, Message::Control(ControlMessage{
 
                id: causer_id,
 
                sender_comp_id: comp_ctx.id,
 
                target_port_id: None,
 
                content: ControlMessageContent::Ack,
 
            }), true);
 
        }
 

	
 
        // Handle the content of the control message, and optionally Ack it
 
        match message.content {
 
            ControlMessageContent::Ack => {
 
                self.handle_ack(sched_ctx, comp_ctx, message.id);
 
            },
 
            ControlMessageContent::BlockPort(port_id) => {
 
                // On of our messages was accepted, but the port should be
 
                // blocked.
 
                let port_handle = comp_ctx.get_port_handle(port_id);
 
                let port_info = comp_ctx.get_port(port_handle);
 
                debug_assert_eq!(port_info.kind, PortKind::Putter);
 
                if port_info.state == PortState::Open {
 
                    // only when open: we don't do this when closed, and we we don't do this if we're blocked due to peer changes
 
                    comp_ctx.set_port_state(port_handle, PortState::BlockedDueToFullBuffers);
 
                }
 
            },
 
            ControlMessageContent::ClosePort(port_id) => {
 
                // Request to close the port. We immediately comply and remove
 
                // the component handle as well
 
                let port_handle = comp_ctx.get_port_handle(port_id);
 
                let peer_comp_id = comp_ctx.get_port(port_handle).peer_comp_id;
 
                let peer_handle = comp_ctx.get_peer_handle(peer_comp_id);
 

	
 
                // One exception to sending an `Ack` is if we just closed the
 
                // port ourselves, meaning that the `ClosePort` messages got
 
                // sent to one another.
 
                if let Some(control_id) = self.control.has_close_port_entry(port_handle, comp_ctx) {
 
                    self.handle_ack(sched_ctx, comp_ctx, control_id);
 
                } else {
 
                    send_control_ack_message(sched_ctx, comp_ctx, message.id, peer_handle);
 
                    comp_ctx.remove_peer(sched_ctx, port_handle, peer_comp_id, false); // do not remove if closed
 
                    comp_ctx.set_port_state(port_handle, PortState::Closed); // now set to closed
 
                }
 
            },
 
            ControlMessageContent::UnblockPort(port_id) => {
 
                // We were previously blocked (or already closed)
 
                let port_handle = comp_ctx.get_port_handle(port_id);
 
                let port_info = comp_ctx.get_port(port_handle);
 
                debug_assert_eq!(port_info.kind, PortKind::Putter);
 
                if port_info.state == PortState::BlockedDueToFullBuffers {
 
                    self.handle_unblock_port_instruction(sched_ctx, comp_ctx, port_handle);
 
                }
 
            },
 
            ControlMessageContent::PortPeerChangedBlock(port_id) => {
 
                // The peer of our port has just changed. So we are asked to
 
                // temporarily block the port (while our original recipient is
 
                // potentially rerouting some of the in-flight messages) and
 
                // Ack. Then we wait for the `unblock` call.
 
                debug_assert_eq!(message.target_port_id, Some(port_id));
 
                let port_handle = comp_ctx.get_port_handle(port_id);
 
                comp_ctx.set_port_state(port_handle, PortState::BlockedDueToPeerChange);
 

	
 
                let port_info = comp_ctx.get_port(port_handle);
 
                let peer_handle = comp_ctx.get_peer_handle(port_info.peer_comp_id);
 

	
 
                send_control_ack_message(sched_ctx, comp_ctx, message.id, peer_handle);
 
            },
 
            ControlMessageContent::PortPeerChangedUnblock(new_port_id, new_comp_id) => {
 
                let port_handle = comp_ctx.get_port_handle(message.target_port_id.unwrap());
 
                let port_info = comp_ctx.get_port(port_handle);
 
                debug_assert!(port_info.state == PortState::BlockedDueToPeerChange);
 
                let old_peer_id = port_info.peer_comp_id;
 

	
 
                comp_ctx.remove_peer(sched_ctx, port_handle, old_peer_id, false);
 

	
 
                let port_info = comp_ctx.get_port_mut(port_handle);
 
                port_info.peer_comp_id = new_comp_id;
 
                port_info.peer_port_id = new_port_id;
 
                comp_ctx.add_peer(port_handle, sched_ctx, new_comp_id, None);
 
                self.handle_unblock_port_instruction(sched_ctx, comp_ctx, port_handle);
 
            }
 
        }
 
    }
 

	
 
    fn handle_incoming_sync_message(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, message: SyncMessage) {
 
        let decision = self.consensus.receive_sync_message(sched_ctx, comp_ctx, message);
 
        self.handle_sync_decision(sched_ctx, comp_ctx, decision);
 
    }
 

	
 
    /// Little helper that notifies the control layer of an `Ack`, and takes the
 
    /// appropriate subsequent action
 
    fn handle_ack(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, control_id: ControlId) {
 
        let mut to_ack = control_id;
 
        loop {
 
            let (action, new_to_ack) = self.control.handle_ack(to_ack, sched_ctx, comp_ctx);
 
            match action {
 
                AckAction::SendMessage(target_comp, message) => {
 
                    // FIX @NoDirectHandle
 
                    let mut handle = sched_ctx.runtime.get_component_public(target_comp);
 
                    handle.send_message(sched_ctx, Message::Control(message), true);
 
                    let _should_remove = handle.decrement_users();
 
                    debug_assert!(_should_remove.is_none());
 
                },
 
                AckAction::ScheduleComponent(to_schedule) => {
 
                    // FIX @NoDirectHandle
 
                    let mut handle = sched_ctx.runtime.get_component_public(to_schedule);
 

	
 
                    // Note that the component is intentionally not
 
                    // sleeping, so we just wake it up
 
                    debug_assert!(!handle.sleeping.load(std::sync::atomic::Ordering::Acquire));
 
                    let key = unsafe{ to_schedule.upgrade() };
 
                    sched_ctx.runtime.enqueue_work(key);
 
                    let _should_remove = handle.decrement_users();
 
                    debug_assert!(_should_remove.is_none());
 
                },
 
                AckAction::None => {}
 
            }
 

	
 
            match new_to_ack {
 
                Some(new_to_ack) => to_ack = new_to_ack,
 
                None => break,
 
            }
 
        }
 
    }
 

	
 
    // -------------------------------------------------------------------------
 
    // Handling ports
 
    // -------------------------------------------------------------------------
 

	
 
    /// Unblocks a port, potentially continuing execution of the component, in
 
    /// response to a message that told us to unblock a previously blocked
 
    fn handle_unblock_port_instruction(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, port_handle: LocalPortHandle) {
 
        let port_info = comp_ctx.get_port_mut(port_handle);
 
        let port_id = port_info.self_id;
 
        debug_assert!(port_info.state.is_blocked());
 
        port_info.state = PortState::Open;
src/runtime2/component/consensus.rs
Show inline comments
 
@@ -235,380 +235,393 @@ impl SolutionCombiner {
 
            }
 
        }
 
    }
 
}
 

	
 
/// Tracking consensus state
 
pub struct Consensus {
 
    // General state of consensus manager
 
    mapping_counter: u32,
 
    mode: Mode,
 
    // State associated with sync round
 
    round_index: u32,
 
    highest_id: CompId,
 
    ports: Vec<PortAnnotation>,
 
    // State associated with arriving at a solution and being a (temporary)
 
    // leader in the consensus round
 
    solution: SolutionCombiner,
 
}
 

	
 
impl Consensus {
 
    pub(crate) fn new() -> Self {
 
        return Self{
 
            round_index: 0,
 
            highest_id: CompId::new_invalid(),
 
            ports: Vec::new(),
 
            mapping_counter: 0,
 
            mode: Mode::NonSync,
 
            solution: SolutionCombiner::new(),
 
        }
 
    }
 

	
 
    // -------------------------------------------------------------------------
 
    // Managing sync state
 
    // -------------------------------------------------------------------------
 

	
 
    /// Notifies the consensus management that the PDL code has reached the
 
    /// start of a sync block.
 
    pub(crate) fn notify_sync_start(&mut self, comp_ctx: &CompCtx) {
 
        debug_assert_eq!(self.mode, Mode::NonSync);
 
        self.highest_id = comp_ctx.id;
 
        self.mapping_counter = 0;
 
        self.mode = Mode::SyncBusy;
 
        self.make_ports_consistent_with_ctx(comp_ctx);
 
    }
 

	
 
    /// Notifies the consensus management that the PDL code has reached the end
 
    /// of a sync block. A local solution will be submitted, after which we wait
 
    /// until the participants in the round (hopefully) reach a conclusion.
 
    pub(crate) fn notify_sync_end(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &CompCtx) -> SyncRoundDecision {
 
        debug_assert_eq!(self.mode, Mode::SyncBusy);
 
        self.mode = Mode::SyncAwaitingSolution;
 

	
 
        // Submit our port mapping as a solution
 
        let mut local_solution = Vec::with_capacity(self.ports.len());
 
        for port in &self.ports {
 
            if let Some(mapping) = port.mapping {
 
                let port_handle = comp_ctx.get_port_handle(port.self_port_id);
 
                let port_info = comp_ctx.get_port(port_handle);
 
                let new_entry = match port_info.kind {
 
                    PortKind::Putter => SyncLocalSolutionEntry::Putter(SyncSolutionPutterPort{
 
                        self_comp_id: comp_ctx.id,
 
                        self_port_id: port_info.self_id,
 
                        mapping
 
                    }),
 
                    PortKind::Getter => SyncLocalSolutionEntry::Getter(SyncSolutionGetterPort{
 
                        self_comp_id: comp_ctx.id,
 
                        self_port_id: port_info.self_id,
 
                        peer_comp_id: port.peer_comp_id,
 
                        peer_port_id: port.peer_port_id,
 
                        mapping
 
                    })
 
                };
 
                local_solution.push(new_entry);
 
            }
 
        }
 

	
 
        let decision = self.handle_local_solution(sched_ctx, comp_ctx, comp_ctx.id, local_solution);
 
        return decision;
 
    }
 

	
 
    /// Notifies that a decision has been reached. Note that the caller should
 
    /// still take the appropriate actions based on the decision it is supplying
 
    /// to the consensus layer.
 
    pub(crate) fn notify_sync_decision(&mut self, _decision: SyncRoundDecision) {
 
        // Reset everything for the next round
 
        debug_assert_eq!(self.mode, Mode::SyncAwaitingSolution);
 
        self.mode = Mode::NonSync;
 
        self.round_index = self.round_index.wrapping_add(1);
 

	
 
        for port in self.ports.iter_mut() {
 
            port.mapping = None;
 
        }
 

	
 
        self.solution.clear();
 
    }
 

	
 
    fn make_ports_consistent_with_ctx(&mut self, comp_ctx: &CompCtx) {
 
        let mut needs_setting_ports = false;
 
        if comp_ctx.num_ports() != self.ports.len() {
 
            needs_setting_ports = true;
 
        } else {
 
            for (idx, port) in comp_ctx.iter_ports().enumerate() {
 
                let comp_port_id = port.self_id;
 
                let cons_port_id = self.ports[idx].self_port_id;
 
                if comp_port_id != cons_port_id {
 
                    needs_setting_ports = true;
 
                    break;
 
                }
 
            }
 
        }
 

	
 
        if needs_setting_ports {
 
            self.ports.clear();
 
            self.ports.reserve(comp_ctx.num_ports());
 
            for port in comp_ctx.iter_ports() {
 
                self.ports.push(PortAnnotation::new(comp_ctx.id, port.self_id))
 
            }
 
        }
 
    }
 

	
 
    // -------------------------------------------------------------------------
 
    // Handling inbound and outbound messages
 
    // -------------------------------------------------------------------------
 

	
 
    pub(crate) fn annotate_data_message(&mut self, comp_ctx: &CompCtx, port_info: &Port, content: ValueGroup) -> DataMessage {
 
        debug_assert_eq!(self.mode, Mode::SyncBusy); // can only send between sync start and sync end
 
        debug_assert!(self.ports.iter().any(|v| v.self_port_id == port_info.self_id));
 
        let data_header = self.create_data_header_and_update_mapping(port_info);
 
        let sync_header = self.create_sync_header(comp_ctx);
 

	
 
        return DataMessage{ data_header, sync_header, content };
 
    }
 

	
 
    /// Checks if the data message can be received (due to port annotations), if
 
    /// it can then `true` is returned and the caller is responsible for handing
 
    /// the message of to the PDL code. Otherwise the message cannot be
 
    /// received.
 
    pub(crate) fn try_receive_data_message(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, message: &DataMessage) -> bool {
 
        debug_assert_eq!(self.mode, Mode::SyncBusy);
 
        debug_assert!(self.ports.iter().any(|v| v.self_port_id == message.data_header.target_port));
 

	
 
        // Make sure the expected mapping matches the currently stored mapping
 
        for (expected_id, expected_annotation) in &message.data_header.expected_mapping {
 
            let got_annotation = self.get_annotation(*expected_id);
 
            if got_annotation != *expected_annotation {
 
                return false;
 
            }
 
        }
 

	
 
        // Expected mapping matches current mapping, so we will receive the message
 
        self.set_annotation(message.sync_header.sending_id, &message.data_header);
 

	
 
        // Handle the sync header embedded within the data message
 
        self.handle_sync_header(sched_ctx, comp_ctx, &message.sync_header);
 

	
 
        return true;
 
    }
 

	
 
    /// Receives the sync message and updates the consensus state appropriately.
 
    pub(crate) fn receive_sync_message(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, message: SyncMessage) -> SyncRoundDecision {
 
        // Whatever happens: handle the sync header (possibly changing the
 
        // currently registered leader)
 
        self.handle_sync_header(sched_ctx, comp_ctx, &message.sync_header);
 

	
 
        match message.content {
 
            SyncMessageContent::NotificationOfLeader => {
 
                return SyncRoundDecision::None;
 
            },
 
            SyncMessageContent::LocalSolution(solution_generator_id, local_solution) => {
 
                return self.handle_local_solution(sched_ctx, comp_ctx, solution_generator_id, local_solution);
 
            },
 
            SyncMessageContent::PartialSolution(partial_solution) => {
 
                return self.handle_partial_solution(sched_ctx, comp_ctx, partial_solution);
 
            },
 
            SyncMessageContent::GlobalSolution => {
 
                debug_assert_eq!(self.mode, Mode::SyncAwaitingSolution); // leader can only find global- if we submitted local solution
 
                return SyncRoundDecision::Solution;
 
            },
 
            SyncMessageContent::GlobalFailure => {
 
                debug_assert_eq!(self.mode, Mode::SyncAwaitingSolution);
 
                return SyncRoundDecision::Failure;
 
            }
 
        }
 
    }
 

	
 
    fn handle_sync_header(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, header: &MessageSyncHeader) {
 
        if header.highest_id.0 > self.highest_id.0 {
 
            // Sender knows of someone with a higher ID. So store highest ID,
 
            // notify all peers, and forward local solutions
 
            self.highest_id = header.highest_id;
 
            for peer in comp_ctx.iter_peers() {
 
                if peer.id == header.sending_id {
 
                    continue; // do not send to sender: it has the higher ID
 
                }
 

	
 
                // also: only send if we received a message in this round
 
                let mut performed_communication = false; // TODO: Revise, temporary fix
 
                for port in self.ports.iter() {
 
                    if port.peer_comp_id == peer.id && port.mapping.is_some() {
 
                        performed_communication = true;
 
                        break;
 
                    }
 
                }
 

	
 
                if !performed_communication {
 
                    continue;
 
                }
 

	
 
                let message = SyncMessage{
 
                    sync_header: self.create_sync_header(comp_ctx),
 
                    content: SyncMessageContent::NotificationOfLeader,
 
                };
 
                peer.handle.send_message(sched_ctx, Message::Sync(message), true);
 
            }
 

	
 
            self.forward_partial_solution(sched_ctx, comp_ctx);
 
        } else if header.highest_id.0 < self.highest_id.0 {
 
            // Sender has a lower ID, so notify it of our higher one
 
            let message = SyncMessage{
 
                sync_header: self.create_sync_header(comp_ctx),
 
                content: SyncMessageContent::NotificationOfLeader,
 
            };
 
            let peer_handle = comp_ctx.get_peer_handle(header.sending_id);
 
            let peer_info = comp_ctx.get_peer(peer_handle);
 
            peer_info.handle.send_message(sched_ctx, Message::Sync(message), true);
 
        } // else: exactly equal
 
    }
 

	
 
    fn get_annotation(&self, port_id: PortId) -> Option<u32> {
 
        for annotation in self.ports.iter() {
 
            if annotation.self_port_id == port_id {
 
                return annotation.mapping;
 
            }
 
        }
 

	
 
        debug_assert!(false);
 
        return None;
 
    }
 

	
 
    fn set_annotation(&mut self, source_comp_id: CompId, data_header: &MessageDataHeader) {
 
        for annotation in self.ports.iter_mut() {
 
            if annotation.self_port_id == data_header.target_port {
 
                annotation.peer_comp_id = source_comp_id;
 
                annotation.peer_port_id = data_header.source_port;
 
                annotation.mapping = Some(data_header.new_mapping);
 
            }
 
        }
 
    }
 

	
 
    // -------------------------------------------------------------------------
 
    // Leader-related methods
 
    // -------------------------------------------------------------------------
 

	
 
    fn forward_partial_solution(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx) {
 
        debug_assert_ne!(self.highest_id, comp_ctx.id); // not leader
 

	
 
        // Make sure that we have something to send
 
        if !self.solution.has_contributions() {
 
            return;
 
        }
 

	
 
        // Swap the container with the partial solution and then send it along
 
        let partial_solution = self.solution.take_partial_solution();
 
        self.send_to_leader(sched_ctx, comp_ctx, Message::Sync(SyncMessage{
 
            sync_header: self.create_sync_header(comp_ctx),
 
            content: SyncMessageContent::PartialSolution(partial_solution),
 
        }));
 
    }
 

	
 
    fn handle_local_solution(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &CompCtx, solution_sender_id: CompId, solution: SyncLocalSolution) -> SyncRoundDecision {
 
        if self.highest_id == comp_ctx.id {
 
            // We are the leader
 
            self.solution.combine_with_local_solution(solution_sender_id, solution);
 
            let round_decision = self.solution.get_decision();
 
            if round_decision != SyncRoundDecision::None {
 
                self.broadcast_decision(sched_ctx, comp_ctx, round_decision);
 
            }
 
            return round_decision;
 
        } else {
 
            // Forward the solution
 
            let message = SyncMessage{
 
                sync_header: self.create_sync_header(comp_ctx),
 
                content: SyncMessageContent::LocalSolution(solution_sender_id, solution),
 
            };
 
            self.send_to_leader(sched_ctx, comp_ctx, Message::Sync(message));
 
            return SyncRoundDecision::None;
 
        }
 
    }
 

	
 
    fn handle_partial_solution(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, solution: SyncPartialSolution) -> SyncRoundDecision {
 
        if self.highest_id == comp_ctx.id {
 
            // We are the leader, combine existing and new solution
 
            self.solution.combine_with_partial_solution(solution);
 
            let round_decision = self.solution.get_decision();
 
            if round_decision != SyncRoundDecision::None {
 
                self.broadcast_decision(sched_ctx, comp_ctx, round_decision);
 
            }
 
            return round_decision;
 
        } else {
 
            // Forward the partial solution
 
            let message = SyncMessage{
 
                sync_header: self.create_sync_header(comp_ctx),
 
                content: SyncMessageContent::PartialSolution(solution),
 
            };
 
            self.send_to_leader(sched_ctx, comp_ctx, Message::Sync(message));
 
            return SyncRoundDecision::None;
 
        }
 
    }
 

	
 
    fn broadcast_decision(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &CompCtx, decision: SyncRoundDecision) {
 
        debug_assert_eq!(self.highest_id, comp_ctx.id);
 

	
 
        let is_success = match decision {
 
            SyncRoundDecision::None => unreachable!(),
 
            SyncRoundDecision::Solution => true,
 
            SyncRoundDecision::Failure => false,
 
        };
 

	
 
        let mut peers = Vec::with_capacity(self.solution.solution.channel_mapping.len()); // TODO: @Performance
 

	
 
        for channel in self.solution.solution.channel_mapping.iter() {
 
            let getter = channel.getter.as_ref().unwrap();
 
            if getter.self_comp_id != comp_ctx.id && !peers.contains(&getter.self_comp_id) {
 
                peers.push(getter.self_comp_id);
 
            }
 
            if getter.peer_comp_id != comp_ctx.id && !peers.contains(&getter.peer_comp_id) {
 
                peers.push(getter.peer_comp_id);
 
            }
 
        }
 

	
 
        for peer in peers {
 
            let mut handle = sched_ctx.runtime.get_component_public(peer);
 
            let message = Message::Sync(SyncMessage{
 
                sync_header: self.create_sync_header(comp_ctx),
 
                content: if is_success { SyncMessageContent::GlobalSolution } else { SyncMessageContent::GlobalFailure },
 
            });
 
            handle.send_message(sched_ctx, message, true);
 
            let _should_remove = handle.decrement_users();
 
            debug_assert!(_should_remove.is_none());
 
        }
 
    }
 

	
 
    fn send_to_leader(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &CompCtx, message: Message) {
 
        debug_assert_ne!(self.highest_id, comp_ctx.id); // we're not the leader
 
        let mut leader_info = sched_ctx.runtime.get_component_public(self.highest_id);
 
        leader_info.send_message(sched_ctx, message, true);
 
        let should_remove = leader_info.decrement_users();
 
        if let Some(key) = should_remove {
 
            sched_ctx.runtime.destroy_component(key);
 
        }
 
    }
 

	
 
    // -------------------------------------------------------------------------
 
    // Creating message headers
 
    // -------------------------------------------------------------------------
 

	
 
    fn create_data_header_and_update_mapping(&mut self, port_info: &Port) -> MessageDataHeader {
 
        let mut expected_mapping = Vec::with_capacity(self.ports.len());
 
        let mut port_index = usize::MAX;
 
        for (index, port) in self.ports.iter().enumerate() {
 
            if port.self_port_id == port_info.self_id {
 
                port_index = index;
 
            }
 
            expected_mapping.push((port.self_port_id, port.mapping));
 
        }
 

	
 
        let new_mapping = self.take_mapping();
 
        self.ports[port_index].mapping = Some(new_mapping);
 
        debug_assert_eq!(port_info.kind, PortKind::Putter);
 
        return MessageDataHeader{
 
            expected_mapping,
 
            new_mapping,
 
            source_port: port_info.self_id,
 
            target_port: port_info.peer_port_id,
 
        };
 
    }
 

	
 
    #[inline]
 
    fn create_sync_header(&self, comp_ctx: &CompCtx) -> MessageSyncHeader {
 
        return MessageSyncHeader{
 
            sync_round: self.round_index,
 
            sending_id: comp_ctx.id,
 
            highest_id: self.highest_id,
 
        };
 
    }
 

	
 
    #[inline]
 
    fn take_mapping(&mut self) -> u32 {
 
        let mapping = self.mapping_counter;
 
        self.mapping_counter = self.mapping_counter.wrapping_add(1);
 
        return mapping;
 
    }
 
}
 
\ No newline at end of file
src/runtime2/tests/mod.rs
Show inline comments
 
use crate::protocol::*;
 
use crate::protocol::eval::*;
 
use crate::runtime2::runtime::*;
 
use crate::runtime2::component::{CompCtx, CompPDL};
 

	
 
fn create_component(rt: &Runtime, module_name: &str, routine_name: &str, args: ValueGroup) {
 
    let prompt = rt.inner.protocol.new_component(
 
        module_name.as_bytes(), routine_name.as_bytes(), args
 
    ).expect("create prompt");
 
    let reserved = rt.inner.start_create_pdl_component();
 
    let ctx = CompCtx::new(&reserved);
 
    let (key, _) = rt.inner.finish_create_pdl_component(reserved, CompPDL::new(prompt, 0), ctx, false);
 
    rt.inner.enqueue_work(key);
 
}
 

	
 
fn no_args() -> ValueGroup { ValueGroup::new_stack(Vec::new()) }
 

	
 
#[test]
 
fn test_component_creation() {
 
    let pd = ProtocolDescription::parse(b"
 
    primitive nothing_at_all() {
 
        s32 a = 5;
 
        auto b = 5 + a;
 
    }
 
    ").expect("compilation");
 
    let rt = Runtime::new(1, true, pd);
 

	
 
    for _i in 0..20 {
 
        create_component(&rt, "", "nothing_at_all", no_args());
 
    }
 
}
 

	
 
#[test]
 
fn test_component_communication() {
 
    let pd = ProtocolDescription::parse(b"
 
    primitive sender(out<u32> o, u32 outside_loops, u32 inside_loops) {
 
        u32 outside_index = 0;
 
        while (outside_index < outside_loops) {
 
            u32 inside_index = 0;
 
            sync while (inside_index < inside_loops) {
 
                put(o, inside_index);
 
                inside_index += 1;
 
            }
 
            outside_index += 1;
 
        }
 
    }
 

	
 
    primitive receiver(in<u32> i, u32 outside_loops, u32 inside_loops) {
 
        u32 outside_index = 0;
 
        while (outside_index < outside_loops) {
 
            u32 inside_index = 0;
 
            sync while (inside_index < inside_loops) {
 
                auto val = get(i);
 
                while (val != inside_index) {} // infinite loop if incorrect value is received
 
                inside_index += 1;
 
            }
 
            outside_index += 1;
 
        }
 
    }
 

	
 
    composite constructor() {
 
        channel o_orom -> i_orom;
 
        channel o_mrom -> i_mrom;
 
        channel o_ormm -> i_ormm;
 
        channel o_mrmm -> i_mrmm;
 

	
 
        // one round, one message per round
 
        new sender(o_orom, 1, 1);
 
        new receiver(i_orom, 1, 1);
 

	
 
        // multiple rounds, one message per round
 
        new sender(o_mrom, 5, 1);
 
        new receiver(i_mrom, 5, 1);
 

	
 
        // one round, multiple messages per round
 
        new sender(o_ormm, 1, 5);
 
        new receiver(i_ormm, 1, 5);
 

	
 
        // multiple rounds, multiple messages per round
 
        new sender(o_mrmm, 5, 5);
 
        new receiver(i_mrmm, 5, 5);
 
    }").expect("compilation");
 
    let rt = Runtime::new(3, true, pd);
 
    create_component(&rt, "", "constructor", no_args());
 
}
 

	
 
#[test]
 
fn test_simple_select() {
 
    let pd = ProtocolDescription::parse(b"
 
    func infinite_assert<T>(T val, T expected) -> () {
 
        while (val != expected) { print(\"nope!\"); }
 
        return ();
 
    }
 

	
 
    primitive receiver(in<u32> in_a, in<u32> in_b, u32 num_sends) {
 
        auto num_from_a = 0;
 
        auto num_from_b = 0;
 
        while (num_from_a + num_from_b < 2 * num_sends) {
 
            sync select {
 
                auto v = get(in_a) -> {
 
                    print(\"got something from A\");
 
                    infinite_assert(v, num_from_a);
 
                    auto _ = infinite_assert(v, num_from_a);
 
                    num_from_a += 1;
 
                }
 
                auto v = get(in_b) -> {
 
                    print(\"got something from B\");
 
                    infinite_assert(v, num_from_b);
 
                    num_from_b +=1;
 
                    auto _ = infinite_assert(v, num_from_b);
 
                    num_from_b += 1;
 
                }
 
            }
 
        }
 
    }
 

	
 
    primitive sender(out<u32> tx, u32 num_sends) {
 
        auto index = 0;
 
        while (index < num_sends) {
 
            sync {
 
                put(tx, index);
 
                index += 1;
 
            }
 
        }
 
    }
 

	
 
    composite constructor() {
 
        auto num_sends = 3;
 
        auto num_sends = 15;
 
        channel tx_a -> rx_a;
 
        channel tx_b -> rx_b;
 
        new sender(tx_a, num_sends);
 
        new receiver(rx_a, rx_b, num_sends);
 
        new sender(tx_b, num_sends);
 
    }
 
    ").expect("compilation");
 
    let rt = Runtime::new(1, true, pd);
 
    let rt = Runtime::new(3, false, pd);
 
    create_component(&rt, "", "constructor", no_args());
 
}
 
\ No newline at end of file
0 comments (0 inline, 0 general)