diff --git a/Cargo.toml b/Cargo.toml index 7e7090b3c3a0324c9fa86cc906390a012002fe1e..9c3801265cc09b5ee9216ea9d148e4ce62661815 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,7 @@ lazy_static = "1.4.0" libc = { version = "^0.2", optional = true } os_socketaddr = { version = "0.1.0", optional = true } -[dev-dependencies] +# randomness rand = "0.8.4" rand_pcg = "0.3.1" diff --git a/docs/runtime/sync.md b/docs/runtime/sync.md new file mode 100644 index 0000000000000000000000000000000000000000..ac3c9fcbe4b40abdf95617b092af5a2f8a1ccd1d --- /dev/null +++ b/docs/runtime/sync.md @@ -0,0 +1,3 @@ +# Synchronous Communication + +## \ No newline at end of file diff --git a/src/collections/raw_vec.rs b/src/collections/raw_vec.rs index c1b4806d59ab89348ef4c6225d1660cf1a1291aa..91ecb47b9c7f79f7d79db83ed638628104da6a22 100644 --- a/src/collections/raw_vec.rs +++ b/src/collections/raw_vec.rs @@ -62,20 +62,6 @@ impl RawVec { } } - /// Moves the elements in the range [from_idx, from_idx + num_to_move) to - /// the range [to_idx, to_idx + num_to_move). Caller must make sure that all - /// non-overlapping elements of the second range had their destructor called - /// in case those elements were used. - pub fn move_range(&mut self, from_idx: usize, to_idx: usize, num_to_move: usize) { - debug_assert!(from_idx + num_to_move <= self.len); - debug_assert!(to_idx + num_to_move <= self.len); // maybe not in future, for now this is fine - unsafe { - let source = self.base.add(from_idx); - let target = self.base.add(to_idx); - std::ptr::copy(source, target, num_to_move); - } - } - pub fn len(&self) -> usize { return self.len; } @@ -140,9 +126,7 @@ impl Drop for RawVec { let (_, layout) = self.current_layout(); unsafe { dealloc(self.base as *mut u8, layout); - if cfg!(debug_assertions) { - self.base = ptr::null_mut(); - } + dbg_code!({ self.base = ptr::null_mut(); }); } } } diff --git a/src/collections/scoped_buffer.rs b/src/collections/scoped_buffer.rs index 7f5194cf31bc6ef317490d82f6dc413fb07e2272..da789984728f3c4796ea46d3715fc91395ab10fd 100644 --- a/src/collections/scoped_buffer.rs +++ b/src/collections/scoped_buffer.rs @@ -73,54 +73,80 @@ pub(crate) struct ScopedSection { } impl ScopedSection { + /// Pushes value into section #[inline] pub(crate) fn push(&mut self, value: T) { + self.check_length(); let vec = unsafe{&mut *self.inner}; - hide!(debug_assert_eq!( - vec.len(), self.cur_size as usize, - "trying to push onto section, but size is larger than expected" - )); vec.push(value); hide!(self.cur_size += 1); } #[inline] pub(crate) fn len(&self) -> usize { + self.check_length(); let vec = unsafe{&mut *self.inner}; - hide!(debug_assert_eq!( - vec.len(), self.cur_size as usize, - "trying to get section length, but size is larger than expected" - )); return vec.len() - self.start_size as usize; } #[inline] #[allow(unused_mut)] // used in debug mode pub(crate) fn forget(mut self) { + self.check_length(); let vec = unsafe{&mut *self.inner}; - hide!({ - debug_assert_eq!( - vec.len(), self.cur_size as usize, - "trying to forget section, but size is larger than expected" - ); - self.cur_size = self.start_size; - }); + hide!(self.cur_size = self.start_size); vec.truncate(self.start_size as usize); } #[inline] #[allow(unused_mut)] // used in debug mode pub(crate) fn into_vec(mut self) -> Vec { + self.check_length(); let vec = unsafe{&mut *self.inner}; + hide!(self.cur_size = self.start_size); + let section = Vec::from_iter(vec.drain(self.start_size as usize..)); + section + } + + #[inline] + pub(crate) fn 
check_length(&self) { hide!({ + let vec = unsafe{&*self.inner}; debug_assert_eq!( vec.len(), self.cur_size as usize, - "trying to turn section into vec, but size is larger than expected" - ); - self.cur_size = self.start_size; - }); - let section = Vec::from_iter(vec.drain(self.start_size as usize..)); - section + "incorrect use of ScopedSection: underlying storage vector has changed size" + ) + }) + } +} + +impl ScopedSection { + #[inline] + pub(crate) fn push_unique(&mut self, value: T) { + self.check_length(); + let vec = unsafe{&mut *self.inner}; + for item in &vec[self.start_size as usize..] { + if *item == value { + // item already exists + return; + } + } + + vec.push(value); + hide!(self.cur_size += 1); + } + + #[inline] + pub(crate) fn contains(&self, value: &T) -> bool { + self.check_length(); + let vec = unsafe{&*self.inner}; + for index in self.start_size as usize..vec.len() { + if &vec[index] == value { + return true; + } + } + + return false; } } diff --git a/src/collections/string_pool.rs b/src/collections/string_pool.rs index 5a31cb902ff697833f222b25997d0e90574810b4..58fc38bd5524f9abf76135fc40b9d3f4066d414f 100644 --- a/src/collections/string_pool.rs +++ b/src/collections/string_pool.rs @@ -1,4 +1,4 @@ -use std::ptr::null_mut; +use std::ptr::{null_mut, null}; use std::hash::{Hash, Hasher}; use std::marker::PhantomData; use std::fmt::{Debug, Display, Formatter, Result as FmtResult}; @@ -29,6 +29,12 @@ impl<'a> StringRef<'a> { StringRef{ data, length, _phantom: PhantomData } } + /// `new_empty` creates a empty StringRef. It is a null pointer with a + /// length of zero. + pub(crate) const fn new_empty() -> StringRef<'static> { + StringRef{ data: null(), length: 0, _phantom: PhantomData } + } + pub fn as_str(&self) -> &'a str { unsafe { let slice = std::slice::from_raw_parts::<'a, u8>(self.data, self.length); @@ -161,6 +167,13 @@ unsafe impl Send for StringPool {} mod tests { use super::*; + #[test] + fn display_empty_string_ref() { + // Makes sure that null pointer inside StringRef will not cause issues + let v = StringRef::new_empty(); + let _val = format!("{}{:?}", v, v); // calls Format and Debug on StringRef + } + #[test] fn test_string_just_fits() { let large = "0".repeat(SLAB_SIZE); diff --git a/src/common.rs b/src/common.rs deleted file mode 100644 index 08587130bf657da1b637756355705f7d25c40ad5..0000000000000000000000000000000000000000 --- a/src/common.rs +++ /dev/null @@ -1,245 +0,0 @@ -///////////////////// PRELUDE ///////////////////// -pub(crate) use crate::protocol::{ComponentState, ProtocolDescription}; -pub(crate) use crate::runtime::{error::AddComponentError, NonsyncProtoContext, SyncProtoContext}; -pub(crate) use core::{ - cmp::Ordering, - fmt::{Debug, Formatter}, - hash::Hash, - ops::Range, - time::Duration, -}; -pub(crate) use maplit::hashmap; -pub(crate) use mio::{ - net::{TcpListener, TcpStream}, - Events, Interest, Poll, Token, -}; -pub(crate) use std::{ - collections::{BTreeMap, HashMap, HashSet}, - io::{Read, Write}, - net::SocketAddr, - sync::Arc, - time::Instant, -}; -pub(crate) use Polarity::*; - -pub(crate) trait IdParts { - fn id_parts(self) -> (ConnectorId, U32Suffix); -} - -/// Used by various distributed algorithms to identify connectors. 
-pub type ConnectorId = u32; - -/// Used in conjunction with the `ConnectorId` type to create identifiers for ports and components -pub type U32Suffix = u32; -#[derive(Copy, Clone, Eq, PartialEq, Ord, Hash, PartialOrd)] - -/// Generalization of a port/component identifier -#[derive(serde::Serialize, serde::Deserialize)] -#[repr(C)] -pub struct Id { - pub(crate) connector_id: ConnectorId, - pub(crate) u32_suffix: U32Suffix, -} -#[derive(Clone, Debug, Default)] -pub struct U32Stream { - next: u32, -} - -/// Identifier of a component in a session -#[derive(Copy, Clone, Eq, PartialEq, Ord, Hash, PartialOrd, serde::Serialize, serde::Deserialize)] -pub struct ComponentId(Id); // PUB because it can be returned by errors - -/// Identifier of a port in a session -#[derive(Copy, Clone, Eq, PartialEq, Ord, Hash, PartialOrd, serde::Serialize, serde::Deserialize)] -#[repr(transparent)] -pub struct PortId(pub(crate) Id); - -impl PortId { - // TODO: Remove concept of ComponentId and PortId in this file - #[deprecated] - pub fn new(port: u32) -> Self { - return PortId(Id{ - connector_id: u32::MAX, - u32_suffix: port, - }); - } -} - -/// A safely aliasable heap-allocated payload of message bytes -#[derive(Default, Eq, PartialEq, Clone, Ord, PartialOrd)] -pub struct Payload(pub Arc>); -#[derive(Debug, Eq, PartialEq, Clone, Hash, Copy, Ord, PartialOrd)] - -/// "Orientation" of a port, determining whether they can send or receive messages with `put` and `get` respectively. -#[repr(C)] -#[derive(serde::Serialize, serde::Deserialize)] -pub enum Polarity { - Putter, // output port (from the perspective of the component) - Getter, // input port (from the perspective of the component) -} -#[derive(Debug, Eq, PartialEq, Clone, Hash, Copy, Ord, PartialOrd)] - -/// "Orientation" of a transport-layer network endpoint, dictating how it's connection procedure should -/// be conducted. Corresponds with connect() / accept() familiar to TCP socket programming. -#[repr(C)] -pub enum EndpointPolarity { - Active, // calls connect() - Passive, // calls bind() listen() accept() -} - -#[derive(Debug, Clone)] -pub(crate) enum NonsyncBlocker { - Inconsistent, - ComponentExit, - SyncBlockStart, -} -#[derive(Debug, Clone)] -pub(crate) enum SyncBlocker { - Inconsistent, - SyncBlockEnd, - CouldntReadMsg(PortId), - CouldntCheckFiring(PortId), - PutMsg(PortId, Payload), -} -pub(crate) struct DenseDebugHex<'a>(pub &'a [u8]); -pub(crate) struct DebuggableIter + Clone, T: Debug>(pub(crate) I); -///////////////////// IMPL ///////////////////// -impl IdParts for Id { - fn id_parts(self) -> (ConnectorId, U32Suffix) { - (self.connector_id, self.u32_suffix) - } -} -impl IdParts for PortId { - fn id_parts(self) -> (ConnectorId, U32Suffix) { - self.0.id_parts() - } -} -impl IdParts for ComponentId { - fn id_parts(self) -> (ConnectorId, U32Suffix) { - self.0.id_parts() - } -} -impl U32Stream { - pub(crate) fn next(&mut self) -> u32 { - if self.next == u32::MAX { - panic!("NO NEXT!") - } - self.next += 1; - self.next - 1 - } - pub(crate) fn n_skipped(mut self, n: u32) -> Self { - self.next = self.next.saturating_add(n); - self - } -} -impl From for PortId { - fn from(id: Id) -> PortId { - Self(id) - } -} -impl From for ComponentId { - fn from(id: Id) -> Self { - Self(id) - } -} -impl From<&[u8]> for Payload { - fn from(s: &[u8]) -> Payload { - Payload(Arc::new(s.to_vec())) - } -} -impl Payload { - /// Create a new payload of uninitialized bytes with the given length. 
- pub fn new(len: usize) -> Payload { - let mut v = Vec::with_capacity(len); - unsafe { - v.set_len(len); - } - Payload(Arc::new(v)) - } - /// Returns the length of the payload's byte sequence - pub fn len(&self) -> usize { - self.0.len() - } - /// Allows shared reading of the payload's contents - pub fn as_slice(&self) -> &[u8] { - &self.0 - } - - /// Allows mutation of the payload's contents. - /// Results in a deep copy in the event this payload is aliased. - pub fn as_mut_vec(&mut self) -> &mut Vec { - Arc::make_mut(&mut self.0) - } - - /// Modifies this payload, concatenating the given immutable payload's contents. - /// Results in a deep copy in the event this payload is aliased. - pub fn concatenate_with(&mut self, other: &Self) { - let bytes = other.as_slice().iter().copied(); - let me = self.as_mut_vec(); - me.extend(bytes); - } -} -impl serde::Serialize for Payload { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let inner: &Vec = &self.0; - inner.serialize(serializer) - } -} -impl<'de> serde::Deserialize<'de> for Payload { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let inner: Vec = Vec::deserialize(deserializer)?; - Ok(Self(Arc::new(inner))) - } -} -impl From> for Payload { - fn from(s: Vec) -> Self { - Self(s.into()) - } -} -impl Debug for PortId { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - let (a, b) = self.id_parts(); - write!(f, "pid{}_{}", a, b) - } -} -impl Debug for ComponentId { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - let (a, b) = self.id_parts(); - write!(f, "cid{}_{}", a, b) - } -} -impl Debug for Payload { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - write!(f, "Payload[{:?}]", DenseDebugHex(self.as_slice())) - } -} -impl std::ops::Not for Polarity { - type Output = Self; - fn not(self) -> Self::Output { - use Polarity::*; - match self { - Putter => Getter, - Getter => Putter, - } - } -} -impl Debug for DenseDebugHex<'_> { - fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { - for b in self.0 { - write!(f, "{:02X?}", b)?; - } - Ok(()) - } -} - -impl + Clone, T: Debug> Debug for DebuggableIter { - fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { - f.debug_list().entries(self.0.clone()).finish() - } -} diff --git a/src/lib.rs b/src/lib.rs index fa75b6de1590bf488509d07a6a803de321356ece..69d61ed758a89489632c27ddfd5f90426d5d4fcc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,5 +6,6 @@ mod protocol; pub mod runtime; pub mod runtime2; mod collections; +mod random; pub use protocol::{ProtocolDescription, ProtocolDescriptionBuilder, ComponentCreationError}; \ No newline at end of file diff --git a/src/macros.rs b/src/macros.rs index 7777d1438c26ff8c770480bee68c179d5490e364..e27cb8c2d16d10bef8c5165b96b8d6d091bde3fe 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -18,4 +18,32 @@ macro_rules! dbg_code { ($code:stmt) => { #[cfg(debug_assertions)] $code } +} + +// Given a function name, return type and variant, will generate the all-so +// common `union_value.as_variant()` method. The return value is the reference +// to the embedded union type. +macro_rules! union_cast_to_ref_method_impl { + ($func_name:ident, $ret_type:ty, $variant:path) => { + fn $func_name(&self) -> &$ret_type { + match self { + $variant(content) => return content, + _ => unreachable!(), + } + } + } +} + +// Another union cast, but now returning a copy of the value +macro_rules! 
union_cast_to_value_method_impl { + ($func_name:ident, $ret_type:ty, $variant:path) => { + impl Value { + pub(crate) fn $func_name(&self) -> $ret_type { + match self { + $variant(v) => *v, + _ => unreachable!(), + } + } + } + } } \ No newline at end of file diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index da5ba6db6c6a05dde8bafa165679eda350dd7256..e137985fa110c6e2cea9a0f83be7cedf79aef88d 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -5,6 +5,7 @@ use std::ops::{Index, IndexMut}; use super::arena::{Arena, Id}; use crate::collections::StringRef; use crate::protocol::input_source::InputSpan; +use crate::protocol::TypeId; /// Helper macro that defines a type alias for a AST element ID. In this case /// only used to alias the `Id` types. @@ -117,8 +118,7 @@ define_aliased_ast_id!(DefinitionId, Id, index(Definition, definitio define_new_ast_id!(StructDefinitionId, DefinitionId, index(StructDefinition, Definition::Struct, definitions), alloc(alloc_struct_definition)); define_new_ast_id!(EnumDefinitionId, DefinitionId, index(EnumDefinition, Definition::Enum, definitions), alloc(alloc_enum_definition)); define_new_ast_id!(UnionDefinitionId, DefinitionId, index(UnionDefinition, Definition::Union, definitions), alloc(alloc_union_definition)); -define_new_ast_id!(ComponentDefinitionId, DefinitionId, index(ComponentDefinition, Definition::Component, definitions), alloc(alloc_component_definition)); -define_new_ast_id!(FunctionDefinitionId, DefinitionId, index(FunctionDefinition, Definition::Function, definitions), alloc(alloc_function_definition)); +define_new_ast_id!(ProcedureDefinitionId, DefinitionId, index(ProcedureDefinition, Definition::Procedure, definitions), alloc(alloc_procedure_definition)); define_aliased_ast_id!(StatementId, Id, index(Statement, statements)); define_new_ast_id!(BlockStatementId, StatementId, index(BlockStatement, Statement::Block, statements), alloc(alloc_block_statement)); @@ -158,6 +158,8 @@ define_new_ast_id!(CastExpressionId, ExpressionId, index(CastExpression, Express define_new_ast_id!(CallExpressionId, ExpressionId, index(CallExpression, Expression::Call, expressions), alloc(alloc_call_expression)); define_new_ast_id!(VariableExpressionId, ExpressionId, index(VariableExpression, Expression::Variable, expressions), alloc(alloc_variable_expression)); +define_aliased_ast_id!(ScopeId, Id, index(Scope, scopes), alloc(alloc_scope)); + #[derive(Debug)] pub struct Heap { // Root arena, contains the entry point for different modules. 
Each root @@ -170,6 +172,7 @@ pub struct Heap { pub(crate) definitions: Arena, pub(crate) statements: Arena, pub(crate) expressions: Arena, + pub(crate) scopes: Arena, } impl Heap { @@ -183,6 +186,7 @@ impl Heap { definitions: Arena::new(), statements: Arena::new(), expressions: Arena::new(), + scopes: Arena::new(), } } pub fn alloc_memory_statement( @@ -210,14 +214,20 @@ impl Heap { impl Index for Heap { type Output = MemoryStatement; fn index(&self, index: MemoryStatementId) -> &Self::Output { - &self.statements[index.0.0].as_memory() + match &self.statements[index.0.0] { + Statement::Local(LocalStatement::Memory(v)) => v, + _ => unreachable!(), + } } } impl Index for Heap { type Output = ChannelStatement; fn index(&self, index: ChannelStatementId) -> &Self::Output { - &self.statements[index.0.0].as_channel() + match &self.statements[index.0.0] { + Statement::Local(LocalStatement::Channel(v)) => v, + _ => unreachable!(), + } } } @@ -340,6 +350,15 @@ pub struct Identifier { pub value: StringRef<'static>, } +impl Identifier { + pub(crate) const fn new_empty(span: InputSpan) -> Identifier { + return Identifier{ + span, + value: StringRef::new_empty(), + }; + } +} + impl PartialEq for Identifier { fn eq(&self, other: &Self) -> bool { return self.value == other.value @@ -497,12 +516,13 @@ pub enum ConcreteTypePart { Slice, Input, Output, + Pointer, // Tuple: variable number of nested types, will never be 1 Tuple(u32), // User defined type with any number of nested types Instance(DefinitionId, u32), // instance of data type - Function(DefinitionId, u32), // instance of function - Component(DefinitionId, u32), // instance of a connector + Function(ProcedureDefinitionId, u32), // instance of function + Component(ProcedureDefinitionId, u32), // instance of a connector } impl ConcreteTypePart { @@ -515,7 +535,7 @@ impl ConcreteTypePart { SInt8 | SInt16 | SInt32 | SInt64 | Character | String => 0, - Array | Slice | Input | Output => + Array | Slice | Input | Output | Pointer => 1, Tuple(num_embedded) => *num_embedded, Instance(_, num_embedded) => *num_embedded, @@ -622,6 +642,10 @@ impl ConcreteType { idx = Self::render_type_part_at(parts, heap, idx, target); target.push('>'); }, + CTP::Pointer => { + target.push('*'); + idx = Self::render_type_part_at(parts, heap, idx, target); + } CTP::Tuple(num_parts) => { target.push('('); if num_parts != 0 { @@ -633,27 +657,35 @@ impl ConcreteType { } target.push(')'); }, - CTP::Instance(definition_id, num_poly_args) | + CTP::Instance(definition_id, num_poly_args) => { + idx = Self::render_definition_type_parts_at(parts, heap, definition_id, num_poly_args, idx, target); + } CTP::Function(definition_id, num_poly_args) | CTP::Component(definition_id, num_poly_args) => { - let definition = &heap[definition_id]; - target.push_str(definition.identifier().value.as_str()); - - if num_poly_args != 0 { - target.push('<'); - for poly_arg_idx in 0..num_poly_args { - if poly_arg_idx != 0 { - target.push(','); - } - idx = Self::render_type_part_at(parts, heap, idx, target); - } - target.push('>'); - } + idx = Self::render_definition_type_parts_at(parts, heap, definition_id.upcast(), num_poly_args, idx, target); } } idx } + + fn render_definition_type_parts_at(parts: &[ConcreteTypePart], heap: &Heap, definition_id: DefinitionId, num_poly_args: u32, mut idx: usize, target: &mut String) -> usize { + let definition = &heap[definition_id]; + target.push_str(definition.identifier().value.as_str()); + + if num_poly_args != 0 { + target.push('<'); + for poly_arg_idx in 
0..num_poly_args { + if poly_arg_idx != 0 { + target.push(','); + } + idx = Self::render_type_part_at(parts, heap, idx, target); + } + target.push('>'); + } + + return idx; + } } #[derive(Debug)] @@ -696,60 +728,66 @@ impl<'a> Iterator for ConcreteTypeIter<'a> { } #[derive(Debug, Clone, Copy)] -pub enum Scope { +pub enum ScopeAssociation { Definition(DefinitionId), - Regular(BlockStatementId), - Synchronous(SynchronousStatementId, BlockStatementId), -} - -impl Scope { - pub(crate) fn new_invalid() -> Scope { - return Scope::Definition(DefinitionId::new_invalid()); - } - - pub(crate) fn is_invalid(&self) -> bool { - match self { - Scope::Definition(id) => id.is_invalid(), - _ => false, - } - } - - pub fn is_block(&self) -> bool { - match &self { - Scope::Definition(_) => false, - Scope::Regular(_) => true, - Scope::Synchronous(_, _) => true, - } - } - pub fn to_block(&self) -> BlockStatementId { - match &self { - Scope::Regular(id) => *id, - Scope::Synchronous(_, id) => *id, - _ => panic!("unable to get BlockStatement from Scope") - } - } + Block(BlockStatementId), + If(IfStatementId, bool), // if true, then body of "if", otherwise body of "else" + While(WhileStatementId), + Synchronous(SynchronousStatementId), + SelectCase(SelectStatementId, u32), // index is select case } /// `ScopeNode` is a helper that links scopes in two directions. It doesn't /// actually contain any information associated with the scope, this may be /// found on the AST elements that `Scope` points to. #[derive(Debug, Clone)] -pub struct ScopeNode { - pub parent: Scope, - pub nested: Vec, +pub struct Scope { + // Relation to other scopes + pub this: ScopeId, + pub parent: Option, + pub nested: Vec, + // Locally available variables/labels + pub association: ScopeAssociation, + pub variables: Vec, + pub labels: Vec, + // Location trackers/counters pub relative_pos_in_parent: i32, + pub first_unique_id_in_scope: i32, + pub next_unique_id_in_scope: i32, } -impl ScopeNode { - pub(crate) fn new_invalid() -> Self { - ScopeNode{ - parent: Scope::new_invalid(), +impl Scope { + pub(crate) fn new(id: ScopeId, association: ScopeAssociation) -> Self { + return Self{ + this: id, + parent: None, nested: Vec::new(), + association, + variables: Vec::new(), + labels: Vec::new(), relative_pos_in_parent: -1, + first_unique_id_in_scope: -1, + next_unique_id_in_scope: -1, } } } +impl Scope { + pub(crate) fn new_invalid(this: ScopeId) -> Self { + return Self{ + this, + parent: None, + nested: Vec::new(), + association: ScopeAssociation::Definition(DefinitionId::new_invalid()), + variables: Vec::new(), + labels: Vec::new(), + relative_pos_in_parent: -1, + first_unique_id_in_scope: -1, + next_unique_id_in_scope: -1, + }; + } +} + #[derive(Debug, Clone, PartialEq, Eq)] pub enum VariableKind { Parameter, // in parameter list of function/component @@ -765,17 +803,16 @@ pub struct Variable { pub parser_type: ParserType, pub identifier: Identifier, // Validator/linker - pub relative_pos_in_block: i32, + pub relative_pos_in_parent: i32, pub unique_id_in_scope: i32, // Temporary fix until proper bytecode/asm is generated } -#[derive(Debug, Clone)] +#[derive(Debug)] pub enum Definition { Struct(StructDefinition), Enum(EnumDefinition), Union(UnionDefinition), - Component(ComponentDefinition), - Function(FunctionDefinition), + Procedure(ProcedureDefinition), } impl Definition { @@ -827,71 +864,50 @@ impl Definition { _ => panic!("Unable to cast 'Definition' to 'UnionDefinition'"), } } + pub(crate) fn as_union_mut(&mut self) -> &mut UnionDefinition { 
match self { Definition::Union(result) => result, _ => panic!("Unable to cast 'Definition' to 'UnionDefinition'"), } } - pub fn is_component(&self) -> bool { - match self { - Definition::Component(_) => true, - _ => false, - } - } - pub(crate) fn as_component(&self) -> &ComponentDefinition { - match self { - Definition::Component(result) => result, - _ => panic!("Unable to cast `Definition` to `Component`"), - } - } - pub(crate) fn as_component_mut(&mut self) -> &mut ComponentDefinition { - match self { - Definition::Component(result) => result, - _ => panic!("Unable to cast `Definition` to `Component`"), - } - } - pub fn is_function(&self) -> bool { + + pub fn is_procedure(&self) -> bool { match self { - Definition::Function(_) => true, + Definition::Procedure(_) => true, _ => false, } } - pub(crate) fn as_function(&self) -> &FunctionDefinition { + + pub(crate) fn as_procedure(&self) -> &ProcedureDefinition { match self { - Definition::Function(result) => result, + Definition::Procedure(result) => result, _ => panic!("Unable to cast `Definition` to `Function`"), } } - pub(crate) fn as_function_mut(&mut self) -> &mut FunctionDefinition { + + pub(crate) fn as_procedure_mut(&mut self) -> &mut ProcedureDefinition { match self { - Definition::Function(result) => result, + Definition::Procedure(result) => result, _ => panic!("Unable to cast `Definition` to `Function`"), } } - pub fn parameters(&self) -> &Vec { - match self { - Definition::Component(def) => &def.parameters, - Definition::Function(def) => &def.parameters, - _ => panic!("Called parameters() on {:?}", self) - } - } + pub fn defined_in(&self) -> RootId { match self { Definition::Struct(def) => def.defined_in, Definition::Enum(def) => def.defined_in, Definition::Union(def) => def.defined_in, - Definition::Component(def) => def.defined_in, - Definition::Function(def) => def.defined_in, + Definition::Procedure(def) => def.defined_in, } } + pub fn identifier(&self) -> &Identifier { match self { Definition::Struct(def) => &def.identifier, Definition::Enum(def) => &def.identifier, Definition::Union(def) => &def.identifier, - Definition::Component(def) => &def.identifier, - Definition::Function(def) => &def.identifier, + Definition::Procedure(def) => &def.identifier, } } pub fn poly_vars(&self) -> &Vec { @@ -899,8 +915,7 @@ impl Definition { Definition::Struct(def) => &def.poly_vars, Definition::Enum(def) => &def.poly_vars, Definition::Union(def) => &def.poly_vars, - Definition::Component(def) => &def.poly_vars, - Definition::Function(def) => &def.poly_vars, + Definition::Procedure(def) => &def.poly_vars, } } } @@ -994,75 +1009,106 @@ impl UnionDefinition { } } -#[derive(Debug, Clone, Copy)] -pub enum ComponentVariant { - Primitive, +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ProcedureKind { + Function, // with return type + Primitive, // without return type Composite, } -#[derive(Debug, Clone)] -pub struct ComponentDefinition { - pub this: ComponentDefinitionId, - pub defined_in: RootId, - // Symbol scanning - pub span: InputSpan, - pub variant: ComponentVariant, - pub identifier: Identifier, - pub poly_vars: Vec, - // Parsing - pub parameters: Vec, - pub body: BlockStatementId, - // Validation/linking - pub num_expressions_in_body: i32, +/// Monomorphed instantiation of a procedure (or the sole instantiation of a +/// non-polymorphic procedure). 
+#[derive(Debug)] +pub struct ProcedureDefinitionMonomorph { + pub argument_types: Vec, + pub expr_info: Vec } -impl ComponentDefinition { - // Used for preallocation during symbol scanning - pub(crate) fn new_empty( - this: ComponentDefinitionId, defined_in: RootId, span: InputSpan, - variant: ComponentVariant, identifier: Identifier, poly_vars: Vec - ) -> Self { - Self{ - this, defined_in, span, variant, identifier, poly_vars, - parameters: Vec::new(), - body: BlockStatementId::new_invalid(), - num_expressions_in_body: -1, +impl ProcedureDefinitionMonomorph { + pub(crate) fn new_invalid() -> Self { + return Self{ + argument_types: Vec::new(), + expr_info: Vec::new(), + } + } +} + +#[derive(Debug, Clone, Copy)] +pub struct ExpressionInfo { + pub type_id: TypeId, + pub variant: ExpressionInfoVariant, +} + +impl ExpressionInfo { + pub(crate) fn new_invalid() -> Self { + return Self{ + type_id: TypeId::new_invalid(), + variant: ExpressionInfoVariant::Generic, + } + } +} + +#[derive(Debug, Clone, Copy)] +pub enum ExpressionInfoVariant { + Generic, + Procedure(TypeId, u32), // procedure TypeID and its monomorph index + Select(i32), // index +} + +impl ExpressionInfoVariant { + pub(crate) fn as_select(&self) -> i32 { + match self { + ExpressionInfoVariant::Select(v) => *v, + _ => unreachable!(), + } + } + + pub(crate) fn as_procedure(&self) -> (TypeId, u32) { + match self { + ExpressionInfoVariant::Procedure(type_id, monomorph_index) => (*type_id, *monomorph_index), + _ => unreachable!(), } } } +/// Generic storage for functions, primitive components and composite +/// components. // Note that we will have function definitions for builtin functions as well. In // that case the span, the identifier span and the body are all invalid. -#[derive(Debug, Clone)] -pub struct FunctionDefinition { - pub this: FunctionDefinitionId, +#[derive(Debug)] +pub struct ProcedureDefinition { + pub this: ProcedureDefinitionId, pub defined_in: RootId, // Symbol scanning pub builtin: bool, + pub kind: ProcedureKind, pub span: InputSpan, pub identifier: Identifier, pub poly_vars: Vec, // Parser - pub return_types: Vec, + pub return_type: Option, // present on functions, not components pub parameters: Vec, + pub scope: ScopeId, pub body: BlockStatementId, - // Validation/linking - pub num_expressions_in_body: i32, + // Monomorphization of typed procedures + pub monomorphs: Vec, } -impl FunctionDefinition { +impl ProcedureDefinition { pub(crate) fn new_empty( - this: FunctionDefinitionId, defined_in: RootId, span: InputSpan, - identifier: Identifier, poly_vars: Vec + this: ProcedureDefinitionId, defined_in: RootId, span: InputSpan, + kind: ProcedureKind, identifier: Identifier, poly_vars: Vec ) -> Self { Self { this, defined_in, builtin: false, - span, identifier, poly_vars, - return_types: Vec::new(), + span, + kind, identifier, poly_vars, + return_type: None, parameters: Vec::new(), + scope: ScopeId::new_invalid(), body: BlockStatementId::new_invalid(), - num_expressions_in_body: -1, + monomorphs: Vec::new(), } } } @@ -1092,25 +1138,6 @@ pub enum Statement { } impl Statement { - pub fn as_block(&self) -> &BlockStatement { - match self { - Statement::Block(result) => result, - _ => panic!("Unable to cast `Statement` to `BlockStatement`"), - } - } - pub fn as_local(&self) -> &LocalStatement { - match self { - Statement::Local(result) => result, - _ => panic!("Unable to cast `Statement` to `LocalStatement`"), - } - } - pub fn as_memory(&self) -> &MemoryStatement { - self.as_local().as_memory() - } - pub fn 
as_channel(&self) -> &ChannelStatement { - self.as_local().as_channel() - } - pub fn as_new(&self) -> &NewStatement { match self { Statement::New(result) => result, @@ -1169,22 +1196,18 @@ impl Statement { | Statement::If(_) => unreachable!(), } } + } #[derive(Debug, Clone)] pub struct BlockStatement { pub this: BlockStatementId, // Phase 1: parser - pub is_implicit: bool, pub span: InputSpan, // of the complete block pub statements: Vec, pub end_block: EndBlockStatementId, // Phase 2: linker - pub scope_node: ScopeNode, - pub first_unique_id_in_scope: i32, // Temporary fix until proper bytecode/asm is generated - pub next_unique_id_in_scope: i32, // Temporary fix until proper bytecode/asm is generated - pub locals: Vec, - pub labels: Vec, + pub scope: ScopeId, pub next: StatementId, } @@ -1210,18 +1233,6 @@ impl LocalStatement { LocalStatement::Channel(stmt) => stmt.this.upcast(), } } - pub fn as_memory(&self) -> &MemoryStatement { - match self { - LocalStatement::Memory(result) => result, - _ => panic!("Unable to cast `LocalStatement` to `MemoryStatement`"), - } - } - pub fn as_channel(&self) -> &ChannelStatement { - match self { - LocalStatement::Channel(result) => result, - _ => panic!("Unable to cast `LocalStatement` to `ChannelStatement`"), - } - } pub fn span(&self) -> InputSpan { match self { LocalStatement::Channel(v) => v.span, @@ -1254,7 +1265,7 @@ pub struct ChannelStatement { pub from: VariableId, // output pub to: VariableId, // input // Phase 2: linker - pub relative_pos_in_block: i32, + pub relative_pos_in_parent: i32, pub next: StatementId, } @@ -1265,7 +1276,7 @@ pub struct LabeledStatement { pub label: Identifier, pub body: StatementId, // Phase 2: linker - pub relative_pos_in_block: i32, + pub relative_pos_in_parent: i32, pub in_sync: SynchronousStatementId, // may be invalid } @@ -1275,11 +1286,17 @@ pub struct IfStatement { // Phase 1: parser pub span: InputSpan, // of the "if" keyword pub test: ExpressionId, - pub true_body: BlockStatementId, - pub false_body: Option, + pub true_case: IfStatementCase, + pub false_case: Option, pub end_if: EndIfStatementId, } +#[derive(Debug, Clone, Copy)] +pub struct IfStatementCase { + pub body: StatementId, + pub scope: ScopeId, +} + #[derive(Debug, Clone)] pub struct EndIfStatement { pub this: EndIfStatementId, @@ -1293,7 +1310,8 @@ pub struct WhileStatement { // Phase 1: parser pub span: InputSpan, // of the "while" keyword pub test: ExpressionId, - pub body: BlockStatementId, + pub scope: ScopeId, + pub body: StatementId, pub end_while: EndWhileStatementId, pub in_sync: SynchronousStatementId, // may be invalid } @@ -1331,7 +1349,8 @@ pub struct SynchronousStatement { pub this: SynchronousStatementId, // Phase 1: parser pub span: InputSpan, // of the "sync" keyword - pub body: BlockStatementId, + pub scope: ScopeId, + pub body: StatementId, pub end_sync: EndSynchronousStatementId, } @@ -1348,8 +1367,8 @@ pub struct ForkStatement { pub this: ForkStatementId, // Phase 1: parser pub span: InputSpan, // of the "fork" keyword - pub left_body: BlockStatementId, - pub right_body: Option, + pub left_body: StatementId, + pub right_body: Option, pub end_fork: EndForkStatementId, } @@ -1366,6 +1385,8 @@ pub struct SelectStatement { pub span: InputSpan, // of the "select" keyword pub cases: Vec, pub end_select: EndSelectStatementId, + pub relative_pos_in_parent: i32, + pub next: StatementId, // note: the select statement will be transformed into other AST elements, this `next` jumps to those replacement statements } #[derive(Debug, Clone)] 
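The statement changes above replace the old per-block scope bookkeeping with `ScopeId` references into the new `Scope` arena (parent link, nested scopes, locally declared variables and labels). The sketch below illustrates, under simplified assumptions, what walking such an arena-backed scope tree looks like; the stripped-down `Scope`/`ScopeId` types and the `visible_variables` helper are invented for illustration and are not the types defined in this patch.

```rust
// Hypothetical, simplified mirror of an arena-backed scope tree: each scope
// stores its parent and the variables declared directly inside it. Name
// lookup over this structure amounts to walking the parent chain.
#[derive(Clone, Copy)]
struct ScopeId(usize);

struct Scope {
    parent: Option<ScopeId>,
    variables: Vec<&'static str>,
}

// Collects every variable visible from `current`, innermost scope first.
fn visible_variables(scopes: &[Scope], mut current: ScopeId) -> Vec<&'static str> {
    let mut result = Vec::new();
    loop {
        let scope = &scopes[current.0];
        result.extend(scope.variables.iter().copied());
        match scope.parent {
            Some(parent) => current = parent,
            None => return result,
        }
    }
}

fn main() {
    // procedure scope -> block scope -> loop-body scope
    let scopes = vec![
        Scope { parent: None, variables: vec!["arg"] },
        Scope { parent: Some(ScopeId(0)), variables: vec!["local"] },
        Scope { parent: Some(ScopeId(1)), variables: vec!["loop_var"] },
    ];
    assert_eq!(visible_variables(&scopes, ScopeId(2)), vec!["loop_var", "local", "arg"]);
}
```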
@@ -1373,7 +1394,8 @@ pub struct SelectCase { // The guard statement of a `select` is either a MemoryStatement or an // ExpressionStatement. Nothing else is allowed by the initial parsing pub guard: StatementId, - pub block: BlockStatementId, + pub body: StatementId, + pub scope: ScopeId, // Phase 2: Validation and Linking pub involved_ports: Vec<(CallExpressionId, ExpressionId)>, // call to `get` and its port argument } @@ -1432,7 +1454,7 @@ pub enum ExpressionParent { Return(ReturnStatementId), New(NewStatementId), ExpressionStmt(ExpressionStatementId), - Expression(ExpressionId, u32) // index within expression (e.g LHS or RHS of expression) + Expression(ExpressionId, u32) // index within expression (e.g LHS or RHS of expression, or index in array literal, etc.) } impl ExpressionParent { @@ -1530,6 +1552,23 @@ impl Expression { } } + pub fn parent_mut(&mut self) -> &mut ExpressionParent { + match self { + Expression::Assignment(expr) => &mut expr.parent, + Expression::Binding(expr) => &mut expr.parent, + Expression::Conditional(expr) => &mut expr.parent, + Expression::Binary(expr) => &mut expr.parent, + Expression::Unary(expr) => &mut expr.parent, + Expression::Indexing(expr) => &mut expr.parent, + Expression::Slicing(expr) => &mut expr.parent, + Expression::Select(expr) => &mut expr.parent, + Expression::Literal(expr) => &mut expr.parent, + Expression::Cast(expr) => &mut expr.parent, + Expression::Call(expr) => &mut expr.parent, + Expression::Variable(expr) => &mut expr.parent, + } + } + pub fn parent_expr_id(&self) -> Option { if let ExpressionParent::Expression(id, _) = self.parent() { Some(*id) @@ -1538,20 +1577,37 @@ impl Expression { } } - pub fn get_unique_id_in_definition(&self) -> i32 { + pub fn type_index(&self) -> i32 { + match self { + Expression::Assignment(expr) => expr.type_index, + Expression::Binding(expr) => expr.type_index, + Expression::Conditional(expr) => expr.type_index, + Expression::Binary(expr) => expr.type_index, + Expression::Unary(expr) => expr.type_index, + Expression::Indexing(expr) => expr.type_index, + Expression::Slicing(expr) => expr.type_index, + Expression::Select(expr) => expr.type_index, + Expression::Literal(expr) => expr.type_index, + Expression::Cast(expr) => expr.type_index, + Expression::Call(expr) => expr.type_index, + Expression::Variable(expr) => expr.type_index, + } + } + + pub fn type_index_mut(&mut self) -> &mut i32 { match self { - Expression::Assignment(expr) => expr.unique_id_in_definition, - Expression::Binding(expr) => expr.unique_id_in_definition, - Expression::Conditional(expr) => expr.unique_id_in_definition, - Expression::Binary(expr) => expr.unique_id_in_definition, - Expression::Unary(expr) => expr.unique_id_in_definition, - Expression::Indexing(expr) => expr.unique_id_in_definition, - Expression::Slicing(expr) => expr.unique_id_in_definition, - Expression::Select(expr) => expr.unique_id_in_definition, - Expression::Literal(expr) => expr.unique_id_in_definition, - Expression::Cast(expr) => expr.unique_id_in_definition, - Expression::Call(expr) => expr.unique_id_in_definition, - Expression::Variable(expr) => expr.unique_id_in_definition, + Expression::Assignment(expr) => &mut expr.type_index, + Expression::Binding(expr) => &mut expr.type_index, + Expression::Conditional(expr) => &mut expr.type_index, + Expression::Binary(expr) => &mut expr.type_index, + Expression::Unary(expr) => &mut expr.type_index, + Expression::Indexing(expr) => &mut expr.type_index, + Expression::Slicing(expr) => &mut expr.type_index, + 
Expression::Select(expr) => &mut expr.type_index, + Expression::Literal(expr) => &mut expr.type_index, + Expression::Cast(expr) => &mut expr.type_index, + Expression::Call(expr) => &mut expr.type_index, + Expression::Variable(expr) => &mut expr.type_index, } } } @@ -1583,7 +1639,8 @@ pub struct AssignmentExpression { pub right: ExpressionId, // Validator/Linker pub parent: ExpressionParent, - pub unique_id_in_definition: i32, + // Typing + pub type_index: i32, } #[derive(Debug, Clone)] @@ -1596,7 +1653,8 @@ pub struct BindingExpression { pub bound_from: ExpressionId, // Validator/Linker pub parent: ExpressionParent, - pub unique_id_in_definition: i32, + // Typing + pub type_index: i32, } #[derive(Debug, Clone)] @@ -1610,7 +1668,8 @@ pub struct ConditionalExpression { pub false_expression: ExpressionId, // Validator/Linking pub parent: ExpressionParent, - pub unique_id_in_definition: i32, + // Typing + pub type_index: i32, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -1647,7 +1706,8 @@ pub struct BinaryExpression { pub right: ExpressionId, // Validator/Linker pub parent: ExpressionParent, - pub unique_id_in_definition: i32, + // Typing + pub type_index: i32, } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -1668,7 +1728,8 @@ pub struct UnaryExpression { pub expression: ExpressionId, // Validator/Linker pub parent: ExpressionParent, - pub unique_id_in_definition: i32, + // Typing + pub type_index: i32, } #[derive(Debug, Clone)] @@ -1681,7 +1742,8 @@ pub struct IndexingExpression { pub index: ExpressionId, // Validator/Linker pub parent: ExpressionParent, - pub unique_id_in_definition: i32, + // Typing + pub type_index: i32, } #[derive(Debug, Clone)] @@ -1695,7 +1757,8 @@ pub struct SlicingExpression { pub to_index: ExpressionId, // Validator/Linker pub parent: ExpressionParent, - pub unique_id_in_definition: i32, + // Typing + pub type_index: i32, } #[derive(Debug, Clone)] @@ -1714,7 +1777,8 @@ pub struct SelectExpression { pub kind: SelectKind, // Validator/Linker pub parent: ExpressionParent, - pub unique_id_in_definition: i32, + // Typing + pub type_index: i32, } #[derive(Debug, Clone)] @@ -1727,7 +1791,8 @@ pub struct CastExpression { pub subject: ExpressionId, // Validator/linker pub parent: ExpressionParent, - pub unique_id_in_definition: i32, + // Typing + pub type_index: i32, } #[derive(Debug, Clone)] @@ -1739,15 +1804,16 @@ pub struct CallExpression { pub parser_type: ParserType, // of the function call, not the return type pub method: Method, pub arguments: Vec, - pub definition: DefinitionId, + pub procedure: ProcedureDefinitionId, // Validator/Linker pub parent: ExpressionParent, - pub unique_id_in_definition: i32, + // Typing + pub type_index: i32, } #[derive(Debug, Clone, PartialEq, Eq)] pub enum Method { - // Builtin + // Builtin, accessible by programmer Get, Put, Fires, @@ -1755,14 +1821,31 @@ pub enum Method { Length, Assert, Print, + // Builtin, not accessible by programmer + SelectStart, // SelectStart(total_num_cases, total_num_ports) + SelectRegisterCasePort, // SelectRegisterCasePort(case_index, port_index, port_id) + SelectWait, // SelectWait() -> u32 + // User-defined UserFunction, UserComponent, } -#[derive(Debug, Clone)] -pub struct MethodSymbolic { - pub(crate) parser_type: ParserType, - pub(crate) definition: DefinitionId +impl Method { + pub(crate) fn is_public_builtin(&self) -> bool { + use Method::*; + match self { + Get | Put | Fires | Create | Length | Assert | Print => true, + _ => false, + } + } + + pub(crate) fn is_user_defined(&self) -> bool { + 
use Method::*; + match self { + UserFunction | UserComponent => true, + _ => false, + } + } } #[derive(Debug, Clone)] @@ -1773,7 +1856,8 @@ pub struct LiteralExpression { pub value: Literal, // Validator/Linker pub parent: ExpressionParent, - pub unique_id_in_definition: i32, + // Typing + pub type_index: i32, } #[derive(Debug, Clone)] @@ -1870,5 +1954,6 @@ pub struct VariableExpression { pub declaration: Option, pub used_as_binding_target: bool, pub parent: ExpressionParent, - pub unique_id_in_definition: i32, + // Typing + pub type_index: i32, } \ No newline at end of file diff --git a/src/protocol/ast_printer.rs b/src/protocol/ast_printer.rs index 5529cd3dde91d3b9834ab3cc4ecce3d8641eef9d..22a5db0376dda2d918d174df982d37d19cec4970 100644 --- a/src/protocol/ast_printer.rs +++ b/src/protocol/ast_printer.rs @@ -345,7 +345,7 @@ impl ASTWriter { } } } - Definition::Function(def) => { + Definition::Procedure(def) => { self.kv(indent).with_id(PREFIX_FUNCTION_ID, def.this.0.index) .with_s_key("DefinitionFunction"); @@ -354,10 +354,10 @@ impl ASTWriter { self.kv(indent3).with_s_key("PolyVar").with_identifier_val(&poly_var_id); } - self.kv(indent2).with_s_key("ReturnParserTypes"); - for return_type in &def.return_types { - self.kv(indent3).with_s_key("ReturnParserType") - .with_custom_val(|s| write_parser_type(s, heap, return_type)); + self.kv(indent2).with_s_key("Kind").with_debug_val(&def.kind); + if let Some(parser_type) = &def.return_type { + self.kv(indent2).with_s_key("ReturnParserType") + .with_custom_val(|s| write_parser_type(s, heap, parser_type)); } self.kv(indent2).with_s_key("Parameters"); @@ -368,26 +368,6 @@ impl ASTWriter { self.kv(indent2).with_s_key("Body"); self.write_stmt(heap, def.body.upcast(), indent3); }, - Definition::Component(def) => { - self.kv(indent).with_id(PREFIX_COMPONENT_ID,def.this.0.index) - .with_s_key("DefinitionComponent"); - - self.kv(indent2).with_s_key("Name").with_identifier_val(&def.identifier); - self.kv(indent2).with_s_key("Variant").with_debug_val(&def.variant); - - self.kv(indent2).with_s_key("PolymorphicVariables"); - for poly_var_id in &def.poly_vars { - self.kv(indent3).with_s_key("PolyVar").with_identifier_val(&poly_var_id); - } - - self.kv(indent2).with_s_key("Parameters"); - for variable_id in &def.parameters { - self.write_variable(heap, *variable_id, indent3) - } - - self.kv(indent2).with_s_key("Body"); - self.write_stmt(heap, def.body.upcast(), indent3); - } } } @@ -401,9 +381,7 @@ impl ASTWriter { self.kv(indent).with_id(PREFIX_BLOCK_STMT_ID, stmt.this.0.index) .with_s_key("Block"); self.kv(indent2).with_s_key("EndBlockID").with_disp_val(&stmt.end_block.0.index); - self.kv(indent2).with_s_key("FirstUniqueScopeID").with_disp_val(&stmt.first_unique_id_in_scope); - self.kv(indent2).with_s_key("NextUniqueScopeID").with_disp_val(&stmt.next_unique_id_in_scope); - self.kv(indent2).with_s_key("RelativePos").with_disp_val(&stmt.scope_node.relative_pos_in_parent); + self.kv(indent2).with_s_key("ScopeID").with_disp_val(&stmt.scope.index); self.kv(indent2).with_s_key("Statements"); for stmt_id in &stmt.statements { @@ -457,11 +435,11 @@ impl ASTWriter { self.write_expr(heap, stmt.test, indent3); self.kv(indent2).with_s_key("TrueBody"); - self.write_stmt(heap, stmt.true_body.upcast(), indent3); + self.write_stmt(heap, stmt.true_case.body, indent3); - if let Some(false_body) = stmt.false_body { + if let Some(false_body) = stmt.false_case { self.kv(indent2).with_s_key("FalseBody"); - self.write_stmt(heap, false_body.upcast(), indent3); + 
self.write_stmt(heap, false_body.body, indent3); } }, Statement::EndIf(stmt) => { @@ -480,7 +458,7 @@ impl ASTWriter { self.kv(indent2).with_s_key("Condition"); self.write_expr(heap, stmt.test, indent3); self.kv(indent2).with_s_key("Body"); - self.write_stmt(heap, stmt.body.upcast(), indent3); + self.write_stmt(heap, stmt.body, indent3); }, Statement::EndWhile(stmt) => { self.kv(indent).with_id(PREFIX_ENDWHILE_STMT_ID, stmt.this.0.index) @@ -509,7 +487,7 @@ impl ASTWriter { .with_s_key("Synchronous"); self.kv(indent2).with_s_key("EndSync").with_disp_val(&stmt.end_sync.0.index); self.kv(indent2).with_s_key("Body"); - self.write_stmt(heap, stmt.body.upcast(), indent3); + self.write_stmt(heap, stmt.body, indent3); }, Statement::EndSynchronous(stmt) => { self.kv(indent).with_id(PREFIX_ENDSYNC_STMT_ID, stmt.this.0.index) @@ -522,11 +500,11 @@ impl ASTWriter { .with_s_key("Fork"); self.kv(indent2).with_s_key("EndFork").with_disp_val(&stmt.end_fork.0.index); self.kv(indent2).with_s_key("LeftBody"); - self.write_stmt(heap, stmt.left_body.upcast(), indent3); + self.write_stmt(heap, stmt.left_body, indent3); if let Some(right_body_id) = stmt.right_body { self.kv(indent2).with_s_key("RightBody"); - self.write_stmt(heap, right_body_id.upcast(), indent3); + self.write_stmt(heap, right_body_id, indent3); } }, Statement::EndFork(stmt) => { @@ -547,8 +525,10 @@ impl ASTWriter { self.write_stmt(heap, case.guard, indent4); self.kv(indent3).with_s_key("Block"); - self.write_stmt(heap, case.block.upcast(), indent4); + self.write_stmt(heap, case.body, indent4); } + self.kv(indent2).with_s_key("Replacement"); + self.write_stmt(heap, stmt.next, indent3); }, Statement::EndSelect(stmt) => { self.kv(indent).with_id(PREFIX_END_SELECT_STMT_ID, stmt.this.0.index) @@ -596,6 +576,7 @@ impl ASTWriter { Expression::Assignment(expr) => { self.kv(indent).with_id(PREFIX_ASSIGNMENT_EXPR_ID, expr.this.0.index) .with_s_key("AssignmentExpr"); + self.kv(indent2).with_s_key("TypeIndex").with_disp_val(&expr.type_index); self.kv(indent2).with_s_key("Operation").with_debug_val(&expr.operation); self.kv(indent2).with_s_key("Left"); self.write_expr(heap, expr.left, indent3); @@ -607,6 +588,7 @@ impl ASTWriter { Expression::Binding(expr) => { self.kv(indent).with_id(PREFIX_BINARY_EXPR_ID, expr.this.0.index) .with_s_key("BindingExpr"); + self.kv(indent2).with_s_key("TypeIndex").with_disp_val(&expr.type_index); self.kv(indent2).with_s_key("BindToExpression"); self.write_expr(heap, expr.bound_to, indent3); self.kv(indent2).with_s_key("BindFromExpression"); @@ -617,6 +599,7 @@ impl ASTWriter { Expression::Conditional(expr) => { self.kv(indent).with_id(PREFIX_CONDITIONAL_EXPR_ID, expr.this.0.index) .with_s_key("ConditionalExpr"); + self.kv(indent2).with_s_key("TypeIndex").with_disp_val(&expr.type_index); self.kv(indent2).with_s_key("Condition"); self.write_expr(heap, expr.test, indent3); self.kv(indent2).with_s_key("TrueExpression"); @@ -629,6 +612,7 @@ impl ASTWriter { Expression::Binary(expr) => { self.kv(indent).with_id(PREFIX_BINARY_EXPR_ID, expr.this.0.index) .with_s_key("BinaryExpr"); + self.kv(indent2).with_s_key("TypeIndex").with_disp_val(&expr.type_index); self.kv(indent2).with_s_key("Operation").with_debug_val(&expr.operation); self.kv(indent2).with_s_key("Left"); self.write_expr(heap, expr.left, indent3); @@ -640,6 +624,7 @@ impl ASTWriter { Expression::Unary(expr) => { self.kv(indent).with_id(PREFIX_UNARY_EXPR_ID, expr.this.0.index) .with_s_key("UnaryExpr"); + 
self.kv(indent2).with_s_key("TypeIndex").with_disp_val(&expr.type_index); self.kv(indent2).with_s_key("Operation").with_debug_val(&expr.operation); self.kv(indent2).with_s_key("Argument"); self.write_expr(heap, expr.expression, indent3); @@ -649,6 +634,7 @@ impl ASTWriter { Expression::Indexing(expr) => { self.kv(indent).with_id(PREFIX_INDEXING_EXPR_ID, expr.this.0.index) .with_s_key("IndexingExpr"); + self.kv(indent2).with_s_key("TypeIndex").with_disp_val(&expr.type_index); self.kv(indent2).with_s_key("Subject"); self.write_expr(heap, expr.subject, indent3); self.kv(indent2).with_s_key("Index"); @@ -659,6 +645,7 @@ impl ASTWriter { Expression::Slicing(expr) => { self.kv(indent).with_id(PREFIX_SLICING_EXPR_ID, expr.this.0.index) .with_s_key("SlicingExpr"); + self.kv(indent2).with_s_key("TypeIndex").with_disp_val(&expr.type_index); self.kv(indent2).with_s_key("Subject"); self.write_expr(heap, expr.subject, indent3); self.kv(indent2).with_s_key("FromIndex"); @@ -671,6 +658,7 @@ impl ASTWriter { Expression::Select(expr) => { self.kv(indent).with_id(PREFIX_SELECT_EXPR_ID, expr.this.0.index) .with_s_key("SelectExpr"); + self.kv(indent2).with_s_key("TypeIndex").with_disp_val(&expr.type_index); self.kv(indent2).with_s_key("Subject"); self.write_expr(heap, expr.subject, indent3); @@ -690,6 +678,7 @@ impl ASTWriter { self.kv(indent).with_id(PREFIX_LITERAL_EXPR_ID, expr.this.0.index) .with_s_key("LiteralExpr"); + self.kv(indent2).with_s_key("TypeIndex").with_disp_val(&expr.type_index); let val = self.kv(indent2).with_s_key("Value"); match &expr.value { Literal::Null => { val.with_s_val("null"); }, @@ -765,6 +754,7 @@ impl ASTWriter { Expression::Cast(expr) => { self.kv(indent).with_id(PREFIX_CAST_EXPR_ID, expr.this.0.index) .with_s_key("CallExpr"); + self.kv(indent2).with_s_key("TypeIndex").with_disp_val(&expr.type_index); self.kv(indent2).with_s_key("ToType") .with_custom_val(|t| write_parser_type(t, heap, &expr.to_type)); self.kv(indent2).with_s_key("Subject"); @@ -776,21 +766,16 @@ impl ASTWriter { self.kv(indent).with_id(PREFIX_CALL_EXPR_ID, expr.this.0.index) .with_s_key("CallExpr"); - let definition = &heap[expr.definition]; - match definition { - Definition::Component(definition) => { - self.kv(indent2).with_s_key("BuiltIn").with_disp_val(&false); - self.kv(indent2).with_s_key("Variant").with_debug_val(&definition.variant); - }, - Definition::Function(definition) => { - self.kv(indent2).with_s_key("BuiltIn").with_disp_val(&definition.builtin); - self.kv(indent2).with_s_key("Variant").with_s_val("Function"); - }, - _ => unreachable!() + self.kv(indent2).with_s_key("TypeIndex").with_disp_val(&expr.type_index); + self.kv(indent2).with_s_key("Method").with_debug_val(&expr.method); + if !expr.procedure.is_invalid() { + let definition = &heap[expr.procedure]; + self.kv(indent2).with_s_key("BuiltIn").with_disp_val(&definition.builtin); + self.kv(indent2).with_s_key("Variant").with_debug_val(&definition.kind); + self.kv(indent2).with_s_key("MethodName").with_identifier_val(&definition.identifier); + self.kv(indent2).with_s_key("ParserType") + .with_custom_val(|t| write_parser_type(t, heap, &expr.parser_type)); } - self.kv(indent2).with_s_key("MethodName").with_identifier_val(definition.identifier()); - self.kv(indent2).with_s_key("ParserType") - .with_custom_val(|t| write_parser_type(t, heap, &expr.parser_type)); // Arguments self.kv(indent2).with_s_key("Arguments"); @@ -805,6 +790,7 @@ impl ASTWriter { Expression::Variable(expr) => { self.kv(indent).with_id(PREFIX_VARIABLE_EXPR_ID, 
expr.this.0.index) .with_s_key("VariableExpr"); + self.kv(indent2).with_s_key("TypeIndex").with_disp_val(&expr.type_index); self.kv(indent2).with_s_key("Name").with_identifier_val(&expr.identifier); self.kv(indent2).with_s_key("Definition") .with_opt_disp_val(expr.declaration.as_ref().map(|v| &v.index)); @@ -825,7 +811,7 @@ impl ASTWriter { self.kv(indent2).with_s_key("Kind").with_debug_val(&var.kind); self.kv(indent2).with_s_key("ParserType") .with_custom_val(|w| write_parser_type(w, heap, &var.parser_type)); - self.kv(indent2).with_s_key("RelativePos").with_disp_val(&var.relative_pos_in_block); + self.kv(indent2).with_s_key("RelativePos").with_disp_val(&var.relative_pos_in_parent); self.kv(indent2).with_s_key("UniqueScopeID").with_disp_val(&var.unique_id_in_scope); } @@ -854,6 +840,11 @@ fn write_option(target: &mut String, value: Option) { fn write_parser_type(target: &mut String, heap: &Heap, t: &ParserType) { use ParserTypeVariant as PTV; + if t.elements.is_empty() { + target.push_str("no elements in ParserType (can happen due to compiler-inserted AST nodes)"); + return; + } + fn write_element(target: &mut String, heap: &Heap, t: &ParserType, mut element_idx: usize) -> usize { let element = &t.elements[element_idx]; match &element.variant { @@ -950,7 +941,7 @@ fn write_concrete_type(target: &mut String, heap: &Heap, def_id: DefinitionId, t match &t.parts[idx] { CTP::Void => target.push_str("void"), CTP::Message => target.push_str("msg"), - CTP::Bool => target.push_str("bool"), + CTP::Bool => target.push_str(KW_TYPE_BOOL_STR), CTP::UInt8 => target.push_str(KW_TYPE_UINT8_STR), CTP::UInt16 => target.push_str(KW_TYPE_UINT16_STR), CTP::UInt32 => target.push_str(KW_TYPE_UINT32_STR), @@ -961,6 +952,7 @@ fn write_concrete_type(target: &mut String, heap: &Heap, def_id: DefinitionId, t CTP::SInt64 => target.push_str(KW_TYPE_SINT64_STR), CTP::Character => target.push_str(KW_TYPE_CHAR_STR), CTP::String => target.push_str(KW_TYPE_STRING_STR), + CTP::Pointer => target.push('*'), CTP::Array => { idx = write_concrete_part(target, heap, def_id, t, idx + 1); target.push_str("[]"); diff --git a/src/protocol/eval/error.rs b/src/protocol/eval/error.rs index fc194bb37c8ef31e8084d2613f4cdb9b9a36a344..81128b445a5c36ed8cda7034b2ca9d2f99084a8e 100644 --- a/src/protocol/eval/error.rs +++ b/src/protocol/eval/error.rs @@ -52,18 +52,8 @@ impl EvalError { let statement = &heap[frame.position]; let statement_span = statement.span(); - let (root_id, procedure, is_func) = match definition { - Definition::Function(def) => { - (def.defined_in, def.identifier.value.as_str().to_string(), true) - }, - Definition::Component(def) => { - (def.defined_in, def.identifier.value.as_str().to_string(), false) - }, - _ => unreachable!("construct stack frame with definition pointing to data type") - }; - // Lookup module name, if it has one - let module = modules.iter().find(|m| m.root_id == root_id).unwrap(); + let module = modules.iter().find(|m| m.root_id == definition.defined_in).unwrap(); let module_name = if let Some(name) = &module.name { name.as_str().to_string() } else { @@ -74,8 +64,8 @@ impl EvalError { frames.push(EvalFrame{ line: statement_span.begin.line, module_name, - procedure, - is_func + procedure: definition.identifier.value.as_str().to_string(), + is_func: definition.kind == ProcedureKind::Function, }); } diff --git a/src/protocol/eval/executor.rs b/src/protocol/eval/executor.rs index 00e18b174b087d514ebe4b29e8760a104de3c0bc..950826a2f2cf5d4412098f768ba442675b0b27d7 100644 --- a/src/protocol/eval/executor.rs +++ 
b/src/protocol/eval/executor.rs @@ -26,8 +26,9 @@ pub(crate) enum ExprInstruction { #[derive(Debug, Clone)] pub(crate) struct Frame { - pub(crate) definition: DefinitionId, - pub(crate) monomorph_idx: i32, + pub(crate) definition: ProcedureDefinitionId, + pub(crate) monomorph_type_id: TypeId, + pub(crate) monomorph_index: usize, pub(crate) position: StatementId, pub(crate) expr_stack: VecDeque, // hack for expression evaluation, evaluated by popping from back pub(crate) expr_values: VecDeque, // hack for expression results, evaluated by popping from front/back @@ -36,37 +37,34 @@ pub(crate) struct Frame { impl Frame { /// Creates a new execution frame. Does not modify the stack in any way. - pub fn new(heap: &Heap, definition_id: DefinitionId, monomorph_idx: i32) -> Self { + pub fn new(heap: &Heap, definition_id: ProcedureDefinitionId, monomorph_type_id: TypeId, monomorph_index: u32) -> Self { let definition = &heap[definition_id]; - let first_statement = match definition { - Definition::Component(definition) => definition.body, - Definition::Function(definition) => definition.body, - _ => unreachable!("initializing frame with {:?} instead of a function/component", definition), - }; + let outer_scope_id = definition.scope; + let first_statement_id = definition.body; // Another not-so-pretty thing that has to be replaced somewhere in the // future... - fn determine_max_stack_size(heap: &Heap, block_id: BlockStatementId, max_size: &mut u32) { - let block_stmt = &heap[block_id]; - debug_assert!(block_stmt.next_unique_id_in_scope >= 0); + fn determine_max_stack_size(heap: &Heap, scope_id: ScopeId, max_size: &mut u32) { + let scope = &heap[scope_id]; // Check current block - let cur_size = block_stmt.next_unique_id_in_scope as u32; + let cur_size = scope.next_unique_id_in_scope as u32; if cur_size > *max_size { *max_size = cur_size; } // And child blocks - for child_scope in &block_stmt.scope_node.nested { - determine_max_stack_size(heap, child_scope.to_block(), max_size); + for child_scope in &scope.nested { + determine_max_stack_size(heap, *child_scope, max_size); } } let mut max_stack_size = 0; - determine_max_stack_size(heap, first_statement, &mut max_stack_size); + determine_max_stack_size(heap, outer_scope_id, &mut max_stack_size); Frame{ definition: definition_id, - monomorph_idx, - position: first_statement.upcast(), + monomorph_type_id, + monomorph_index: monomorph_index as usize, + position: first_statement_id.upcast(), expr_stack: VecDeque::with_capacity(128), expr_values: VecDeque::with_capacity(128), max_stack_size, @@ -209,10 +207,13 @@ pub enum EvalContinuation { BlockFires(PortId), BlockGet(PortId), Put(PortId, ValueGroup), + SelectStart(u32, u32), // (num_cases, num_ports_total) + SelectRegisterPort(u32, u32, PortId), // (case_index, port_index_in_case, port_id) + SelectWait, // wait until select can continue // Returned only in non-sync mode ComponentTerminated, SyncBlockStart, - NewComponent(DefinitionId, i32, ValueGroup), + NewComponent(ProcedureDefinitionId, TypeId, ValueGroup), NewChannel, } @@ -225,14 +226,15 @@ pub struct Prompt { } impl Prompt { - pub fn new(_types: &TypeTable, heap: &Heap, def: DefinitionId, monomorph_idx: i32, args: ValueGroup) -> Self { + pub fn new(types: &TypeTable, heap: &Heap, def: ProcedureDefinitionId, type_id: TypeId, args: ValueGroup) -> Self { let mut prompt = Self{ frames: Vec::new(), store: Store::new(), }; // Maybe do typechecking in the future? 
- let new_frame = Frame::new(heap, def, monomorph_idx); + let monomorph_index = types.get_monomorph(type_id).variant.as_procedure().monomorph_index; + let new_frame = Frame::new(heap, def, type_id, monomorph_index); let max_stack_size = new_frame.max_stack_size; prompt.frames.push(new_frame); args.into_store(&mut prompt.store); @@ -292,13 +294,13 @@ impl Prompt { // Checking if we're at the end of execution let cur_frame = self.frames.last_mut().unwrap(); if cur_frame.position.is_invalid() { - if heap[cur_frame.definition].is_function() { + if heap[cur_frame.definition].kind == ProcedureKind::Function { todo!("End of function without return, return an evaluation error"); } return Ok(EvalContinuation::ComponentTerminated); } - debug_log!("Taking step in '{}'", heap[cur_frame.definition].identifier().value.as_str()); + debug_log!("Taking step in '{}'", heap[cur_frame.definition].identifier.value.as_str()); // Execute all pending expressions while !cur_frame.expr_stack.is_empty() { @@ -480,8 +482,8 @@ impl Prompt { }, Expression::Select(expr) => { let subject= cur_frame.expr_values.pop_back().unwrap(); - let mono_data = types.get_procedure_monomorph(cur_frame.monomorph_idx); - let field_idx = mono_data.expr_data[expr.unique_id_in_definition as usize].field_or_monomorph_idx as u32; + let mono_data = &heap[cur_frame.definition].monomorphs[cur_frame.monomorph_index]; + let field_idx = mono_data.expr_info[expr.type_index as usize].variant.as_select() as u32; // Note: same as above: clone if value lives on expr stack, simply // refer to it if it already lives on the stack/heap. @@ -528,8 +530,9 @@ impl Prompt { } Literal::Integer(lit_value) => { use ConcreteTypePart as CTP; - let def_types = types.get_procedure_monomorph(cur_frame.monomorph_idx); - let concrete_type = &def_types.expr_data[expr.unique_id_in_definition as usize].expr_type; + let mono_data = &heap[cur_frame.definition].monomorphs[cur_frame.monomorph_index]; + let type_id = mono_data.expr_info[expr.type_index as usize].type_id; + let concrete_type = &types.get_monomorph(type_id).concrete_type; debug_assert_eq!(concrete_type.parts.len(), 1); match concrete_type.parts[0] { @@ -576,14 +579,15 @@ impl Prompt { cur_frame.expr_values.push_back(value); }, Expression::Cast(expr) => { - let mono_data = types.get_procedure_monomorph(cur_frame.monomorph_idx); - let output_type = &mono_data.expr_data[expr.unique_id_in_definition as usize].expr_type; + let mono_data = &heap[cur_frame.definition].monomorphs[cur_frame.monomorph_index]; + let type_id = mono_data.expr_info[expr.type_index as usize].type_id; + let concrete_type = &types.get_monomorph(type_id).concrete_type; // Typechecking reduced this to two cases: either we // have casting noop (same types), or we're casting // between integer/bool/char types. 
let subject = cur_frame.expr_values.pop_back().unwrap(); - match apply_casting(&mut self.store, output_type, &subject) { + match apply_casting(&mut self.store, concrete_type, &subject) { Ok(value) => cur_frame.expr_values.push_back(value), Err(msg) => { return Err(EvalError::new_error_at_expr(self, modules, heap, expr.this.upcast(), msg)); @@ -651,12 +655,7 @@ impl Prompt { Method::Fires => { let port_value = cur_frame.expr_values.pop_front().unwrap(); let port_value_deref = self.store.maybe_read_ref(&port_value).clone(); - - let port_id = match port_value_deref { - Value::Input(port_id) => port_id, - Value::Output(port_id) => port_id, - _ => unreachable!("executor calling 'fires' on value {:?}", port_value_deref), - }; + let port_id = port_value_deref.as_port_id(); match ctx.fires(port_id) { None => { @@ -734,10 +733,34 @@ impl Prompt { self.store.drop_heap_pos(value_heap_pos); println!("{}", message); }, + Method::SelectStart => { + let num_cases = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_uint32(); + let num_ports = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_uint32(); + + return Ok(EvalContinuation::SelectStart(num_cases, num_ports)); + }, + Method::SelectRegisterCasePort => { + let case_index = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_uint32(); + let port_index = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_uint32(); + let port_value = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_port_id(); + + return Ok(EvalContinuation::SelectRegisterPort(case_index, port_index, port_value)); + }, + Method::SelectWait => { + match ctx.performed_select_wait() { + Some(select_index) => { + cur_frame.expr_values.push_back(Value::UInt32(select_index)); + }, + None => { + cur_frame.expr_stack.push_back(ExprInstruction::EvalExpr(expr.this.upcast())); + return Ok(EvalContinuation::SelectWait) + }, + } + }, Method::UserComponent => { // This is actually handled by the evaluation // of the statement. 
- debug_assert_eq!(heap[expr.definition].parameters().len(), cur_frame.expr_values.len()); + debug_assert_eq!(heap[expr.procedure].parameters.len(), cur_frame.expr_values.len()); debug_assert_eq!(heap[cur_frame.position].as_new().expression, expr.this) }, Method::UserFunction => { @@ -758,11 +781,11 @@ impl Prompt { } // Determine the monomorph index of the function we're calling - let mono_data = types.get_procedure_monomorph(cur_frame.monomorph_idx); - let call_data = &mono_data.expr_data[expr.unique_id_in_definition as usize]; + let mono_data = &heap[cur_frame.definition].monomorphs[cur_frame.monomorph_index]; + let (type_id, monomorph_index) = mono_data.expr_info[expr.type_index as usize].variant.as_procedure(); // Push the new frame and reserve its stack size - let new_frame = Frame::new(heap, expr.definition, call_data.field_or_monomorph_idx); + let new_frame = Frame::new(heap, expr.procedure, type_id, monomorph_index); let new_stack_size = new_frame.max_stack_size; self.frames.push(new_frame); self.store.cur_stack_boundary = new_stack_boundary; @@ -771,7 +794,7 @@ impl Prompt { // To simplify the logic a little bit we will now // return and ask our caller to call us again return Ok(EvalContinuation::Stepping); - }, + } } }, Expression::Variable(expr) => { @@ -817,7 +840,8 @@ impl Prompt { }, Statement::EndBlock(stmt) => { let block = &heap[stmt.start_block]; - self.store.clear_stack(block.first_unique_id_in_scope as usize); + let scope = &heap[block.scope]; + self.store.clear_stack(scope.first_unique_id_in_scope as usize); cur_frame.position = stmt.next; Ok(EvalContinuation::Stepping) @@ -825,13 +849,13 @@ impl Prompt { Statement::Local(stmt) => { match stmt { LocalStatement::Memory(stmt) => { - if cfg!(debug_assertions) { + dbg_code!({ let variable = &heap[stmt.variable]; debug_assert!(match self.store.read_ref(ValueId::Stack(variable.unique_id_in_scope as u32)) { Value::Unassigned => false, _ => true, }); - } + }); cur_frame.position = stmt.next; Ok(EvalContinuation::Stepping) @@ -842,7 +866,7 @@ impl Prompt { match ctx.created_channel() { None => { // No channel is pending. 
So request one - Ok(EvalContinuation::NewChannel) + Ok(EvalContinuation::NewChannel) }, Some((put_port, get_port)) => { self.store.write(ValueId::Stack(heap[stmt.from].unique_id_in_scope as u32), put_port); @@ -864,9 +888,9 @@ impl Prompt { let test_value = cur_frame.expr_values.pop_back().unwrap(); let test_value = self.store.maybe_read_ref(&test_value).as_bool(); if test_value { - cur_frame.position = stmt.true_body.upcast(); - } else if let Some(false_body) = stmt.false_body { - cur_frame.position = false_body.upcast(); + cur_frame.position = stmt.true_case.body; + } else if let Some(false_body) = stmt.false_case { + cur_frame.position = false_body.body; } else { // Not true, and no false body cur_frame.position = stmt.end_if.upcast(); @@ -876,6 +900,13 @@ impl Prompt { }, Statement::EndIf(stmt) => { cur_frame.position = stmt.next; + let if_stmt = &heap[stmt.start_if]; + debug_assert_eq!( + heap[if_stmt.true_case.scope].first_unique_id_in_scope, + heap[if_stmt.false_case.unwrap_or(if_stmt.true_case).scope].first_unique_id_in_scope, + ); + let scope = &heap[if_stmt.true_case.scope]; + self.store.clear_stack(scope.first_unique_id_in_scope as usize); Ok(EvalContinuation::Stepping) }, Statement::While(stmt) => { @@ -883,7 +914,7 @@ impl Prompt { let test_value = cur_frame.expr_values.pop_back().unwrap(); let test_value = self.store.maybe_read_ref(&test_value).as_bool(); if test_value { - cur_frame.position = stmt.body.upcast(); + cur_frame.position = stmt.body; } else { cur_frame.position = stmt.end_while.upcast(); } @@ -892,7 +923,9 @@ impl Prompt { }, Statement::EndWhile(stmt) => { cur_frame.position = stmt.next; - + let start_while = &heap[stmt.start_while]; + let scope = &heap[start_while.scope]; + self.store.clear_stack(scope.first_unique_id_in_scope as usize); Ok(EvalContinuation::Stepping) }, Statement::Break(stmt) => { @@ -906,27 +939,30 @@ impl Prompt { Ok(EvalContinuation::Stepping) }, Statement::Synchronous(stmt) => { - cur_frame.position = stmt.body.upcast(); + cur_frame.position = stmt.body; Ok(EvalContinuation::SyncBlockStart) }, Statement::EndSynchronous(stmt) => { cur_frame.position = stmt.next; + let start_synchronous = &heap[stmt.start_sync]; + let scope = &heap[start_synchronous.scope]; + self.store.clear_stack(scope.first_unique_id_in_scope as usize); Ok(EvalContinuation::SyncBlockEnd) }, Statement::Fork(stmt) => { if stmt.right_body.is_none() { // No reason to fork - cur_frame.position = stmt.left_body.upcast(); + cur_frame.position = stmt.left_body; } else { // Need to fork if let Some(go_left) = ctx.performed_fork() { // Runtime has created a fork if go_left { - cur_frame.position = stmt.left_body.upcast(); + cur_frame.position = stmt.left_body; } else { - cur_frame.position = stmt.right_body.unwrap().upcast(); + cur_frame.position = stmt.right_body.unwrap(); } } else { // Request the runtime to create a fork of the current @@ -942,16 +978,24 @@ impl Prompt { Ok(EvalContinuation::Stepping) }, - Statement::Select(_stmt) => { - todo!("implement select evaluation") + Statement::Select(stmt) => { + // This is a trampoline for the statements that were placed by + // the AST transformation pass + cur_frame.position = stmt.next; + + Ok(EvalContinuation::Stepping) }, Statement::EndSelect(stmt) => { cur_frame.position = stmt.next; + let start_select = &heap[stmt.start_select]; + if let Some(select_case) = start_select.cases.first() { + let scope = &heap[select_case.scope]; + self.store.clear_stack(scope.first_unique_id_in_scope as usize); + } Ok(EvalContinuation::Stepping) }, 
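Note on the select continuations used above: the executor never resolves a select on its own, it yields `SelectStart`, then one `SelectRegisterPort` per guarded port, then `SelectWait`, and resumes once the runtime answers with the index of the runnable case (surfaced through the context's `performed_select_wait`). The following is a minimal, self-contained sketch of that handshake with toy stand-ins (plain integers for ports, a simplified continuation enum); it is not the crate's actual runtime types, just an illustration of the protocol.

// Toy mirror of the continuations yielded while evaluating a select statement.
#[derive(Debug)]
enum ToyContinuation {
    SelectStart(u32, u32),             // (num_cases, num_ports_total)
    SelectRegisterPort(u32, u32, u64), // (case_index, port_index_in_case, toy port id)
    SelectWait,                        // block until one of the cases can run
}

#[derive(Default)]
struct ToySelectState {
    // Ports registered per case, filled in between SelectStart and SelectWait.
    cases: Vec<Vec<u64>>,
}

impl ToySelectState {
    // Returns Some(case_index) only once the executor asks to wait and a
    // registered port has data available; None means "keep stepping".
    fn handle(&mut self, cont: ToyContinuation, ready_port: u64) -> Option<u32> {
        match cont {
            ToyContinuation::SelectStart(num_cases, _num_ports) => {
                self.cases = vec![Vec::new(); num_cases as usize];
                None
            },
            ToyContinuation::SelectRegisterPort(case_index, _port_index, port_id) => {
                self.cases[case_index as usize].push(port_id);
                None
            },
            ToyContinuation::SelectWait => {
                // Report the first case listening on the port that became ready;
                // the executor then pushes this index onto its expression stack.
                self.cases.iter()
                    .position(|ports| ports.contains(&ready_port))
                    .map(|idx| idx as u32)
            },
        }
    }
}

fn main() {
    let mut state = ToySelectState::default();
    assert_eq!(state.handle(ToyContinuation::SelectStart(2, 2), 7), None);
    assert_eq!(state.handle(ToyContinuation::SelectRegisterPort(0, 0, 3), 7), None);
    assert_eq!(state.handle(ToyContinuation::SelectRegisterPort(1, 0, 7), 7), None);
    assert_eq!(state.handle(ToyContinuation::SelectWait, 7), Some(1));
}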
Statement::Return(_stmt) => { - debug_assert!(heap[cur_frame.definition].is_function()); debug_assert_eq!(cur_frame.expr_values.len(), 1, "expected one expr value for return statement"); // The preceding frame has executed a call, so is expecting the @@ -998,14 +1042,13 @@ impl Prompt { }, Statement::New(stmt) => { let call_expr = &heap[stmt.expression]; - debug_assert!(heap[call_expr.definition].is_component()); debug_assert_eq!( - cur_frame.expr_values.len(), heap[call_expr.definition].parameters().len(), + cur_frame.expr_values.len(), heap[call_expr.procedure].parameters.len(), "mismatch in expr stack size and number of arguments for new statement" ); - let mono_data = types.get_procedure_monomorph(cur_frame.monomorph_idx); - let expr_data = &mono_data.expr_data[call_expr.unique_id_in_definition as usize]; + let mono_data = &heap[cur_frame.definition].monomorphs[cur_frame.monomorph_index]; + let type_id = mono_data.expr_info[call_expr.type_index as usize].variant.as_procedure().0; // Note that due to expression value evaluation they exist in // reverse order on the stack. @@ -1017,7 +1060,6 @@ impl Prompt { // Construct argument group, thereby copying heap regions let argument_group = ValueGroup::from_store(&self.store, &args); - // println!("Creating {} with\n{:#?}", heap[call_expr.definition].identifier().value.as_str(), argument_group); // Clear any heap regions for arg in &args { @@ -1026,7 +1068,7 @@ impl Prompt { cur_frame.position = stmt.next; - Ok(EvalContinuation::NewComponent(call_expr.definition, expr_data.field_or_monomorph_idx, argument_group)) + Ok(EvalContinuation::NewComponent(call_expr.procedure, type_id, argument_group)) }, Statement::Expression(stmt) => { // The expression has just been completely evaluated. Some diff --git a/src/protocol/eval/value.rs b/src/protocol/eval/value.rs index b6c8b42c3ca0ec60c9cbcc702bdd854b4cb5fc35..d8bf773b7bc74426a37fa54ad573c4c0d6d8bd00 100644 --- a/src/protocol/eval/value.rs +++ b/src/protocol/eval/value.rs @@ -65,39 +65,26 @@ pub enum Value { Struct(HeapPos), } -macro_rules! 
impl_union_unpack_as_value { - ($func_name:ident, $variant_name:path, $return_type:ty) => { - impl Value { - pub(crate) fn $func_name(&self) -> $return_type { - match self { - $variant_name(v) => *v, - _ => panic!(concat!("called ", stringify!($func_name()), " on {:?}"), self), - } - } - } - } -} - -impl_union_unpack_as_value!(as_stack_boundary, Value::PrevStackBoundary, isize); -impl_union_unpack_as_value!(as_ref, Value::Ref, ValueId); -impl_union_unpack_as_value!(as_input, Value::Input, PortId); -impl_union_unpack_as_value!(as_output, Value::Output, PortId); -impl_union_unpack_as_value!(as_message, Value::Message, HeapPos); -impl_union_unpack_as_value!(as_bool, Value::Bool, bool); -impl_union_unpack_as_value!(as_char, Value::Char, char); -impl_union_unpack_as_value!(as_string, Value::String, HeapPos); -impl_union_unpack_as_value!(as_uint8, Value::UInt8, u8); -impl_union_unpack_as_value!(as_uint16, Value::UInt16, u16); -impl_union_unpack_as_value!(as_uint32, Value::UInt32, u32); -impl_union_unpack_as_value!(as_uint64, Value::UInt64, u64); -impl_union_unpack_as_value!(as_sint8, Value::SInt8, i8); -impl_union_unpack_as_value!(as_sint16, Value::SInt16, i16); -impl_union_unpack_as_value!(as_sint32, Value::SInt32, i32); -impl_union_unpack_as_value!(as_sint64, Value::SInt64, i64); -impl_union_unpack_as_value!(as_array, Value::Array, HeapPos); -impl_union_unpack_as_value!(as_tuple, Value::Tuple, HeapPos); -impl_union_unpack_as_value!(as_enum, Value::Enum, i64); -impl_union_unpack_as_value!(as_struct, Value::Struct, HeapPos); +union_cast_to_value_method_impl!(as_stack_boundary, isize, Value::PrevStackBoundary); +union_cast_to_value_method_impl!(as_ref, ValueId, Value::Ref); +union_cast_to_value_method_impl!(as_input, PortId, Value::Input); +union_cast_to_value_method_impl!(as_output, PortId, Value::Output); +union_cast_to_value_method_impl!(as_message, HeapPos, Value::Message); +union_cast_to_value_method_impl!(as_bool, bool, Value::Bool); +union_cast_to_value_method_impl!(as_char, char, Value::Char); +union_cast_to_value_method_impl!(as_string, HeapPos, Value::String); +union_cast_to_value_method_impl!(as_uint8, u8, Value::UInt8); +union_cast_to_value_method_impl!(as_uint16, u16, Value::UInt16); +union_cast_to_value_method_impl!(as_uint32, u32, Value::UInt32); +union_cast_to_value_method_impl!(as_uint64, u64, Value::UInt64); +union_cast_to_value_method_impl!(as_sint8, i8, Value::SInt8); +union_cast_to_value_method_impl!(as_sint16, i16, Value::SInt16); +union_cast_to_value_method_impl!(as_sint32, i32, Value::SInt32); +union_cast_to_value_method_impl!(as_sint64, i64, Value::SInt64); +union_cast_to_value_method_impl!(as_array, HeapPos, Value::Array); +union_cast_to_value_method_impl!(as_tuple, HeapPos, Value::Tuple); +union_cast_to_value_method_impl!(as_enum, i64, Value::Enum); +union_cast_to_value_method_impl!(as_struct, HeapPos, Value::Struct); impl Value { pub(crate) fn as_union(&self) -> (i64, HeapPos) { @@ -107,6 +94,14 @@ impl Value { } } + pub(crate) fn as_port_id(&self) -> PortId { + match self { + Value::Input(v) => *v, + Value::Output(v) => *v, + _ => unreachable!(), + } + } + pub(crate) fn is_integer(&self) -> bool { match self { Value::UInt8(_) | Value::UInt16(_) | Value::UInt32(_) | Value::UInt64(_) | @@ -520,7 +515,7 @@ pub(crate) fn apply_unary_operator(store: &mut Store, op: UnaryOperator, value: Value::SInt32(v) => Value::SInt32($apply *v), Value::SInt64(v) => Value::SInt64($apply *v), _ => unreachable!("apply_unary_operator {:?} on value {:?}", $op, $value), - }; + } } } @@ -542,7 
+537,7 @@ pub(crate) fn apply_unary_operator(store: &mut Store, op: UnaryOperator, value: _ => unreachable!("apply_unary_operator {:?} on value {:?}", op, value), } }, - UO::BitwiseNot => { apply_int_expr_and_return!(value, !, op)}, + UO::BitwiseNot => { apply_int_expr_and_return!(value, !, op); }, UO::LogicalNot => { return Value::Bool(!value.as_bool()); }, } } diff --git a/src/protocol/input_source.rs b/src/protocol/input_source.rs index fccb0372e7697ff8478444038d3fad561cb7e0f1..f2cf10693ae87bc134610829d91344bcc1f221f1 100644 --- a/src/protocol/input_source.rs +++ b/src/protocol/input_source.rs @@ -21,9 +21,10 @@ pub struct InputSpan { } impl InputSpan { - // This will only be used for builtin functions + // This must only be used if you're sure that the span will not be involved + // in creating an error message. #[inline] - pub fn new() -> InputSpan { + pub const fn new() -> InputSpan { InputSpan{ begin: InputPosition{ line: 0, offset: 0 }, end: InputPosition{ line: 0, offset: 0 }} } diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index c7c4ba07d617734fe1aaab99fcbe75c5634f550b..45f880fbe64dec700ef72086e324877299969e5b 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -16,6 +16,8 @@ use crate::protocol::input_source::*; use crate::protocol::parser::*; use crate::protocol::type_table::*; +pub use parser::type_table::TypeId; + /// A protocol description module pub struct Module { pub(crate) source: InputSource, @@ -92,29 +94,31 @@ impl ProtocolDescription { let definition_id = definition_id.unwrap(); let ast_definition = &self.heap[definition_id]; - if !ast_definition.is_component() { + if !ast_definition.is_procedure() { return Err(ComponentCreationError::DefinitionNotComponent); } // Make sure that the types of the provided value group matches that of // the expected types. 
- let ast_definition = ast_definition.as_component(); - if !ast_definition.poly_vars.is_empty() { + let ast_definition = ast_definition.as_procedure(); + if !ast_definition.poly_vars.is_empty() || ast_definition.kind == ProcedureKind::Function { return Err(ComponentCreationError::DefinitionNotComponent); } // - check number of arguments by retrieving the one instantiated // monomorph - let concrete_type = ConcreteType{ parts: vec![ConcreteTypePart::Component(definition_id, 0)] }; - let mono_index = self.types.get_procedure_monomorph_index(&definition_id, &concrete_type.parts).unwrap(); - let mono_type = self.types.get_procedure_monomorph(mono_index); - if mono_type.arg_types.len() != arguments.values.len() { + let concrete_type = ConcreteType{ parts: vec![ConcreteTypePart::Component(ast_definition.this, 0)] }; + let procedure_type_id = self.types.get_procedure_monomorph_type_id(&definition_id, &concrete_type.parts).unwrap(); + let procedure_monomorph_index = self.types.get_monomorph(procedure_type_id).variant.as_procedure().monomorph_index; + let monomorph_info = &ast_definition.monomorphs[procedure_monomorph_index as usize]; + if monomorph_info.argument_types.len() != arguments.values.len() { return Err(ComponentCreationError::InvalidNumArguments); } // - for each argument try to make sure the types match for arg_idx in 0..arguments.values.len() { - let expected_type = &mono_type.arg_types[arg_idx]; + let expected_type_id = monomorph_info.argument_types[arg_idx]; + let expected_type = &self.types.get_monomorph(expected_type_id).concrete_type; let provided_value = &arguments.values[arg_idx]; if !self.verify_same_type(expected_type, 0, &arguments, provided_value) { return Err(ComponentCreationError::InvalidArgumentType(arg_idx)); @@ -123,7 +127,7 @@ impl ProtocolDescription { // By now we're sure that all of the arguments are correct. So create // the connector. 
- return Ok(Prompt::new(&self.types, &self.heap, definition_id, mono_index, arguments)); + return Ok(Prompt::new(&self.types, &self.heap, ast_definition.this, procedure_type_id, arguments)); } fn lookup_module_root(&self, module_name: &[u8]) -> Option { @@ -145,7 +149,7 @@ impl ProtocolDescription { use ConcreteTypePart as CTP; match &expected.parts[expected_idx] { - CTP::Void | CTP::Message | CTP::Slice | CTP::Function(_, _) | CTP::Component(_, _) => unreachable!(), + CTP::Void | CTP::Message | CTP::Slice | CTP::Pointer | CTP::Function(_, _) | CTP::Component(_, _) => unreachable!(), CTP::Bool => if let Value::Bool(_) = argument { true } else { false }, CTP::UInt8 => if let Value::UInt8(_) = argument { true } else { false }, CTP::UInt16 => if let Value::UInt16(_) = argument { true } else { false }, @@ -211,6 +215,7 @@ pub trait RunContext { fn fires(&mut self, port: PortId) -> Option; // None if not yet branched fn performed_fork(&mut self) -> Option; // None if not yet forked fn created_channel(&mut self) -> Option<(Value, Value)>; // None if not yet prepared + fn performed_select_wait(&mut self) -> Option; // None if not yet notified runtime of select blocker } pub struct ProtocolDescriptionBuilder { diff --git a/src/protocol/parser/mod.rs b/src/protocol/parser/mod.rs index 935ccdcc7d7b7b3e50bc31fbeb2cef2f96d7faff..e55bf6b31adb04d40e19891ed39a219762f616dd 100644 --- a/src/protocol/parser/mod.rs +++ b/src/protocol/parser/mod.rs @@ -1,3 +1,4 @@ +#[macro_use] mod visitor; pub(crate) mod symbol_table; pub(crate) mod type_table; pub(crate) mod tokens; @@ -8,8 +9,9 @@ pub(crate) mod pass_imports; pub(crate) mod pass_definitions; pub(crate) mod pass_definitions_types; pub(crate) mod pass_validation_linking; +pub(crate) mod pass_rewriting; pub(crate) mod pass_typing; -mod visitor; +pub(crate) mod pass_stack_size; use tokens::*; use crate::collections::*; @@ -20,13 +22,16 @@ use pass_imports::PassImport; use pass_definitions::PassDefinitions; use pass_validation_linking::PassValidationLinking; use pass_typing::{PassTyping, ResolveQueue}; +use pass_rewriting::PassRewriting; +use pass_stack_size::PassStackSize; use symbol_table::*; -use type_table::TypeTable; +use type_table::*; use crate::protocol::ast::*; use crate::protocol::input_source::*; use crate::protocol::ast_printer::ASTWriter; +use crate::protocol::parser::type_table::PolymorphicVariable; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] pub enum ModuleCompilationPhase { @@ -36,8 +41,10 @@ pub enum ModuleCompilationPhase { DefinitionsParsed, // produced the AST for the entire module TypesAddedToTable, // added all definitions to the type table ValidatedAndLinked, // AST is traversed and has linked the required AST nodes + Typed, // Type inference and checking has been performed + Rewritten, // Special AST nodes are rewritten into regular AST nodes // When we continue with the compiler: - // Typed, // Type inference and checking has been performed + // StackSize } pub struct Module { @@ -51,16 +58,50 @@ pub struct Module { pub phase: ModuleCompilationPhase, } -// TODO: This is kind of wrong. Because when we're producing bytecode we would -// like the bytecode itself to not have the notion of the size of a pointer -// type. But until I figure out what we do want I'll just set everything -// to a 64-bit architecture. 
pub struct TargetArch { - pub array_size_alignment: (usize, usize), - pub slice_size_alignment: (usize, usize), - pub string_size_alignment: (usize, usize), - pub port_size_alignment: (usize, usize), - pub pointer_size_alignment: (usize, usize), + pub void_type_id: TypeId, + pub message_type_id: TypeId, + pub bool_type_id: TypeId, + pub uint8_type_id: TypeId, + pub uint16_type_id: TypeId, + pub uint32_type_id: TypeId, + pub uint64_type_id: TypeId, + pub sint8_type_id: TypeId, + pub sint16_type_id: TypeId, + pub sint32_type_id: TypeId, + pub sint64_type_id: TypeId, + pub char_type_id: TypeId, + pub string_type_id: TypeId, + pub array_type_id: TypeId, + pub slice_type_id: TypeId, + pub input_type_id: TypeId, + pub output_type_id: TypeId, + pub pointer_type_id: TypeId, +} + +impl TargetArch { + fn new() -> Self { + return Self{ + void_type_id: TypeId::new_invalid(), + bool_type_id: TypeId::new_invalid(), + message_type_id: TypeId::new_invalid(), + uint8_type_id: TypeId::new_invalid(), + uint16_type_id: TypeId::new_invalid(), + uint32_type_id: TypeId::new_invalid(), + uint64_type_id: TypeId::new_invalid(), + sint8_type_id: TypeId::new_invalid(), + sint16_type_id: TypeId::new_invalid(), + sint32_type_id: TypeId::new_invalid(), + sint64_type_id: TypeId::new_invalid(), + char_type_id: TypeId::new_invalid(), + string_type_id: TypeId::new_invalid(), + array_type_id: TypeId::new_invalid(), + slice_type_id: TypeId::new_invalid(), + input_type_id: TypeId::new_invalid(), + output_type_id: TypeId::new_invalid(), + pointer_type_id: TypeId::new_invalid(), + } + } } pub struct PassCtx<'a> { @@ -85,6 +126,8 @@ pub struct Parser { pass_definitions: PassDefinitions, pass_validation: PassValidationLinking, pass_typing: PassTyping, + pass_rewriting: PassRewriting, + pass_stack_size: PassStackSize, // Compiler options pub write_ast_to: Option, pub(crate) arch: TargetArch, @@ -104,18 +147,36 @@ impl Parser { pass_definitions: PassDefinitions::new(), pass_validation: PassValidationLinking::new(), pass_typing: PassTyping::new(), + pass_rewriting: PassRewriting::new(), + pass_stack_size: PassStackSize::new(), write_ast_to: None, - arch: TargetArch { - array_size_alignment: (3*8, 8), // pointer, length, capacity - slice_size_alignment: (2*8, 8), // pointer, length - string_size_alignment: (3*8, 8), // pointer, length, capacity - port_size_alignment: (3*4, 4), // two u32s: connector + port ID - pointer_size_alignment: (8, 8), - } + arch: TargetArch::new(), }; parser.symbol_table.insert_scope(None, SymbolScope::Global); + // Insert builtin types + // TODO: At some point use correct values for size/alignment + parser.arch.void_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::Void], false, 0, 1); + parser.arch.message_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::Message], false, 24, 8); + parser.arch.bool_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::Bool], false, 1, 1); + parser.arch.uint8_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::UInt8], false, 1, 1); + parser.arch.uint16_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::UInt16], false, 2, 2); + parser.arch.uint32_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::UInt32], false, 4, 4); + parser.arch.uint64_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::UInt64], false, 8, 8); + parser.arch.sint8_type_id = insert_builtin_type(&mut parser.type_table, 
vec![ConcreteTypePart::SInt8], false, 1, 1); + parser.arch.sint16_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::SInt16], false, 2, 2); + parser.arch.sint32_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::SInt32], false, 4, 4); + parser.arch.sint64_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::SInt64], false, 8, 8); + parser.arch.char_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::Character], false, 4, 4); + parser.arch.string_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::String], false, 24, 8); + parser.arch.array_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::Array, ConcreteTypePart::Void], true, 24, 8); + parser.arch.slice_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::Slice, ConcreteTypePart::Void], true, 16, 4); + parser.arch.input_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::Input, ConcreteTypePart::Void], true, 8, 8); + parser.arch.output_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::Output, ConcreteTypePart::Void], true, 8, 8); + parser.arch.pointer_type_id = insert_builtin_type(&mut parser.type_table, vec![ConcreteTypePart::Pointer, ConcreteTypePart::Void], true, 8, 8); + + // Insert builtin functions fn quick_type(variants: &[ParserTypeVariant]) -> ParserType { let mut t = ParserType{ elements: Vec::with_capacity(variants.len()), full_span: InputSpan::new() }; for variant in variants { @@ -237,10 +298,10 @@ impl Parser { types: &mut self.type_table, arch: &self.arch, }; - PassTyping::queue_module_definitions(&mut ctx, &mut queue); + self.pass_typing.queue_module_definitions(&mut ctx, &mut queue); }; while !queue.is_empty() { - let top = queue.pop().unwrap(); + let top = queue.pop_front().unwrap(); let mut ctx = visitor::Ctx{ heap: &mut self.heap, modules: &mut self.modules, @@ -252,6 +313,21 @@ impl Parser { self.pass_typing.handle_module_definition(&mut ctx, &mut queue, top)?; } + // Rewrite nodes in tree, then prepare for execution of code + for module_idx in 0..self.modules.len() { + self.modules[module_idx].phase = ModuleCompilationPhase::Typed; + let mut ctx = visitor::Ctx{ + heap: &mut self.heap, + modules: &mut self.modules, + module_idx, + symbols: &mut self.symbol_table, + types: &mut self.type_table, + arch: &self.arch, + }; + self.pass_rewriting.visit_module(&mut ctx)?; + self.pass_stack_size.visit_module(&mut ctx)?; + } + // Write out desired information if let Some(filename) = &self.write_ast_to { let mut writer = ASTWriter::new(); @@ -263,49 +339,74 @@ impl Parser { } } -// Note: args and return type need to be a function because we need to know the function ID. 
-fn insert_builtin_function (Vec<(&'static str, ParserType)>, ParserType)> ( - p: &mut Parser, func_name: &str, polymorphic: &[&str], arg_and_return_fn: T) { +fn insert_builtin_type(type_table: &mut TypeTable, parts: Vec, has_poly_var: bool, size: usize, alignment: usize) -> TypeId { + const POLY_VARS: [PolymorphicVariable; 1] = [PolymorphicVariable{ + identifier: Identifier::new_empty(InputSpan::new()), + is_in_use: false, + }]; - let mut poly_vars = Vec::with_capacity(polymorphic.len()); + let concrete_type = ConcreteType{ parts }; + let poly_var = if has_poly_var { + POLY_VARS.as_slice() + } else { + &[] + }; + + return type_table.add_builtin_data_type(concrete_type, poly_var, size, alignment); +} + +// Note: args and return type need to be a function because we need to know the function ID. +fn insert_builtin_function (Vec<(&'static str, ParserType)>, ParserType)> ( + p: &mut Parser, func_name: &str, polymorphic: &[&str], arg_and_return_fn: T +) { + // Insert into AST (to get an ID), also prepare the polymorphic variables + // we need later for the type table + let mut ast_poly_vars = Vec::with_capacity(polymorphic.len()); + let mut type_poly_vars = Vec::with_capacity(polymorphic.len()); for poly_var in polymorphic { - poly_vars.push(Identifier{ span: InputSpan::new(), value: p.string_pool.intern(poly_var.as_bytes()) }); + let identifier = Identifier{ span: InputSpan::new(), value: p.string_pool.intern(poly_var.as_bytes()) } ; + ast_poly_vars.push(identifier.clone()); + type_poly_vars.push(PolymorphicVariable{ identifier, is_in_use: false }); } let func_ident_ref = p.string_pool.intern(func_name.as_bytes()); - let func_id = p.heap.alloc_function_definition(|this| FunctionDefinition{ + let procedure_id = p.heap.alloc_procedure_definition(|this| ProcedureDefinition { this, defined_in: RootId::new_invalid(), builtin: true, + kind: ProcedureKind::Function, span: InputSpan::new(), identifier: Identifier{ span: InputSpan::new(), value: func_ident_ref.clone() }, - poly_vars, - return_types: Vec::new(), + poly_vars: ast_poly_vars, + return_type: None, parameters: Vec::new(), + scope: ScopeId::new_invalid(), body: BlockStatementId::new_invalid(), - num_expressions_in_body: -1, + monomorphs: Vec::new(), }); - let (args, ret) = arg_and_return_fn(func_id); + // Modify AST with more information about the procedure + let (arguments, return_type) = arg_and_return_fn(procedure_id); - let mut parameters = Vec::with_capacity(args.len()); - for (arg_name, arg_type) in args { + let mut parameters = Vec::with_capacity(arguments.len()); + for (arg_name, arg_type) in arguments { let identifier = Identifier{ span: InputSpan::new(), value: p.string_pool.intern(arg_name.as_bytes()) }; let param_id = p.heap.alloc_variable(|this| Variable{ this, kind: VariableKind::Parameter, parser_type: arg_type.clone(), identifier, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, unique_id_in_scope: 0 }); parameters.push(param_id); } - let func = &mut p.heap[func_id]; + let func = &mut p.heap[procedure_id]; func.parameters = parameters; - func.return_types.push(ret); + func.return_type = Some(return_type); + // Insert into symbol table p.symbol_table.insert_symbol(SymbolScope::Global, Symbol{ name: func_ident_ref, variant: SymbolVariant::Definition(SymbolDefinition{ @@ -315,7 +416,16 @@ fn insert_builtin_function (Vec<(&'static str, Pa identifier_span: InputSpan::new(), imported_at: None, class: DefinitionClass::Function, - definition_id: func_id.upcast(), + definition_id: procedure_id.upcast(), }) }).unwrap(); + + // 
Insert into type table + // let mut concrete_type = ConcreteType::default(); + // concrete_type.parts.push(ConcreteTypePart::Function(procedure_id, type_poly_vars.len() as u32)); + // + // for _ in 0..type_poly_vars.len() { + // concrete_type.parts.push(ConcreteTypePart::Void); // doesn't matter (I hope...) + // } + // p.type_table.add_builtin_procedure_type(concrete_type, &type_poly_vars); } \ No newline at end of file diff --git a/src/protocol/parser/pass_definitions.rs b/src/protocol/parser/pass_definitions.rs index 66ed638cc16da38f731d12225efd5d7c0a7da217..8ce576bcfa720aef0e2e4d9701150fb216f14b7a 100644 --- a/src/protocol/parser/pass_definitions.rs +++ b/src/protocol/parser/pass_definitions.rs @@ -273,35 +273,22 @@ impl PassDefinitions { // Consume return types consume_token(&module.source, iter, TokenKind::ArrowRight)?; - let mut return_types = self.parser_types.start_section(); - let mut open_curly_pos = iter.last_valid_pos(); // bogus value - consume_comma_separated_until( - TokenKind::OpenCurly, &module.source, iter, ctx, - |source, iter, ctx| { - let poly_vars = ctx.heap[definition_id].poly_vars(); - self.type_parser.consume_parser_type( - iter, &ctx.heap, source, &ctx.symbols, poly_vars, definition_id, - module_scope, false, None - ) - }, - &mut return_types, "a return type", Some(&mut open_curly_pos) + let poly_vars = ctx.heap[definition_id].poly_vars(); + let parser_type = self.type_parser.consume_parser_type( + iter, &ctx.heap, &module.source, &ctx.symbols, poly_vars, definition_id, + module_scope, false, None )?; - let return_types = return_types.into_vec(); - - match return_types.len() { - 0 => return Err(ParseError::new_error_str_at_pos(&module.source, open_curly_pos, "expected a return type")), - 1 => {}, - _ => return Err(ParseError::new_error_str_at_pos(&module.source, open_curly_pos, "multiple return types are not (yet) allowed")), - } - // Consume block - let body = self.consume_block_statement_without_leading_curly(module, iter, ctx, open_curly_pos)?; + // Consume block and the definition's scope + let body_id = self.consume_block_statement(module, iter, ctx)?; + let scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::Definition(definition_id))); // Assign everything in the preallocated AST node - let function = ctx.heap[definition_id].as_function_mut(); - function.return_types = return_types; + let function = ctx.heap[definition_id].as_procedure_mut(); + function.return_type = Some(parser_type); function.parameters = parameters; - function.body = body; + function.scope = scope_id; + function.body = body_id; Ok(()) } @@ -330,197 +317,109 @@ impl PassDefinitions { let parameters = parameter_section.into_vec(); // Consume block - let body = self.consume_block_statement(module, iter, ctx)?; + let body_id = self.consume_block_statement(module, iter, ctx)?; + let scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::Definition(definition_id))); // Assign everything in the AST node - let component = ctx.heap[definition_id].as_component_mut(); + let component = ctx.heap[definition_id].as_procedure_mut(); + debug_assert!(component.return_type.is_none()); component.parameters = parameters; - component.body = body; + component.scope = scope_id; + component.body = body_id; Ok(()) } - /// Consumes a block statement. If the resulting statement is not a block - /// (e.g. 
for a shorthand "if (expr) single_statement") then it will be - /// wrapped in one - fn consume_block_or_wrapped_statement( - &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx - ) -> Result { - if Some(TokenKind::OpenCurly) == iter.next() { - // This is a block statement - self.consume_block_statement(module, iter, ctx) - } else { - // Not a block statement, so wrap it in one - let mut statements = self.statements.start_section(); - let wrap_begin_pos = iter.last_valid_pos(); - self.consume_statement(module, iter, ctx, &mut statements)?; - let wrap_end_pos = iter.last_valid_pos(); - - let statements = statements.into_vec(); - - let id = ctx.heap.alloc_block_statement(|this| BlockStatement{ - this, - is_implicit: true, - span: InputSpan::from_positions(wrap_begin_pos, wrap_end_pos), - statements, - end_block: EndBlockStatementId::new_invalid(), - scope_node: ScopeNode::new_invalid(), - first_unique_id_in_scope: -1, - next_unique_id_in_scope: -1, - locals: Vec::new(), - labels: Vec::new(), - next: StatementId::new_invalid(), - }); - - let end_block = ctx.heap.alloc_end_block_statement(|this| EndBlockStatement{ - this, start_block: id, next: StatementId::new_invalid() - }); - - let block_stmt = &mut ctx.heap[id]; - block_stmt.end_block = end_block; - - Ok(id) - } - } - /// Consumes a statement and returns a boolean indicating whether it was a /// block or not. - fn consume_statement( - &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx, section: &mut ScopedSection - ) -> Result<(), ParseError> { + fn consume_statement(&mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx) -> Result { let next = iter.next().expect("consume_statement has a next token"); if next == TokenKind::OpenCurly { let id = self.consume_block_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } else if next == TokenKind::Ident { let ident = peek_ident(&module.source, iter).unwrap(); if ident == KW_STMT_IF { // Consume if statement and place end-if statement directly // after it. 
let id = self.consume_if_statement(module, iter, ctx)?; - section.push(id.upcast()); - - let end_if = ctx.heap.alloc_end_if_statement(|this| EndIfStatement { - this, - start_if: id, - next: StatementId::new_invalid() - }); - section.push(end_if.upcast()); - - let if_stmt = &mut ctx.heap[id]; - if_stmt.end_if = end_if; + return Ok(id.upcast()); } else if ident == KW_STMT_WHILE { let id = self.consume_while_statement(module, iter, ctx)?; - section.push(id.upcast()); - - let end_while = ctx.heap.alloc_end_while_statement(|this| EndWhileStatement { - this, - start_while: id, - next: StatementId::new_invalid() - }); - section.push(end_while.upcast()); - - let while_stmt = &mut ctx.heap[id]; - while_stmt.end_while = end_while; + return Ok(id.upcast()); } else if ident == KW_STMT_BREAK { let id = self.consume_break_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } else if ident == KW_STMT_CONTINUE { let id = self.consume_continue_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } else if ident == KW_STMT_SYNC { let id = self.consume_synchronous_statement(module, iter, ctx)?; - section.push(id.upcast()); - - let end_sync = ctx.heap.alloc_end_synchronous_statement(|this| EndSynchronousStatement { - this, - start_sync: id, - next: StatementId::new_invalid() - }); - section.push(end_sync.upcast()); - - let sync_stmt = &mut ctx.heap[id]; - sync_stmt.end_sync = end_sync; + return Ok(id.upcast()); } else if ident == KW_STMT_FORK { let id = self.consume_fork_statement(module, iter, ctx)?; - section.push(id.upcast()); let end_fork = ctx.heap.alloc_end_fork_statement(|this| EndForkStatement { this, start_fork: id, next: StatementId::new_invalid(), }); - section.push(end_fork.upcast()); let fork_stmt = &mut ctx.heap[id]; fork_stmt.end_fork = end_fork; + + return Ok(id.upcast()); } else if ident == KW_STMT_SELECT { let id = self.consume_select_statement(module, iter, ctx)?; - section.push(id.upcast()); - - let end_select = ctx.heap.alloc_end_select_statement(|this| EndSelectStatement{ - this, - start_select: id, - next: StatementId::new_invalid(), - }); - section.push(end_select.upcast()); - - let select_stmt = &mut ctx.heap[id]; - select_stmt.end_select = end_select; + return Ok(id.upcast()); } else if ident == KW_STMT_RETURN { let id = self.consume_return_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } else if ident == KW_STMT_GOTO { let id = self.consume_goto_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } else if ident == KW_STMT_NEW { let id = self.consume_new_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } else if ident == KW_STMT_CHANNEL { let id = self.consume_channel_statement(module, iter, ctx)?; - section.push(id.upcast().upcast()); + return Ok(id.upcast().upcast()); } else if iter.peek() == Some(TokenKind::Colon) { - self.consume_labeled_statement(module, iter, ctx, section)?; + let id = self.consume_labeled_statement(module, iter, ctx)?; + return Ok(id.upcast()); } else { // Two fallback possibilities: the first one is a memory // declaration, the other one is to parse it as a normal // expression. This is a bit ugly. if let Some(memory_stmt_id) = self.maybe_consume_memory_statement_without_semicolon(module, iter, ctx)? 
{ consume_token(&module.source, iter, TokenKind::SemiColon)?; - section.push(memory_stmt_id.upcast().upcast()); + return Ok(memory_stmt_id.upcast().upcast()); } else { let id = self.consume_expression_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } } } else if next == TokenKind::OpenParen { // Same as above: memory statement or normal expression if let Some(memory_stmt_id) = self.maybe_consume_memory_statement_without_semicolon(module, iter, ctx)? { consume_token(&module.source, iter, TokenKind::SemiColon)?; - section.push(memory_stmt_id.upcast().upcast()); + return Ok(memory_stmt_id.upcast().upcast()); } else { let id = self.consume_expression_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } } else { let id = self.consume_expression_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } - - return Ok(()); } fn consume_block_statement( &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx ) -> Result { - let open_span = consume_token(&module.source, iter, TokenKind::OpenCurly)?; - self.consume_block_statement_without_leading_curly(module, iter, ctx, open_span.begin) - } + let open_curly_span = consume_token(&module.source, iter, TokenKind::OpenCurly)?; - fn consume_block_statement_without_leading_curly( - &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx, open_curly_pos: InputPosition - ) -> Result { let mut stmt_section = self.statements.start_section(); let mut next = iter.next(); while next != Some(TokenKind::CloseCurly) { @@ -529,36 +428,34 @@ impl PassDefinitions { &module.source, iter.last_valid_pos(), "expected a statement or '}'" )); } - self.consume_statement(module, iter, ctx, &mut stmt_section)?; + let stmt_id = self.consume_statement(module, iter, ctx)?; + stmt_section.push(stmt_id); next = iter.next(); } let statements = stmt_section.into_vec(); let mut block_span = consume_token(&module.source, iter, TokenKind::CloseCurly)?; - block_span.begin = open_curly_pos; + block_span.begin = open_curly_span.begin; - let id = ctx.heap.alloc_block_statement(|this| BlockStatement{ + let block_id = ctx.heap.alloc_block_statement(|this| BlockStatement{ this, - is_implicit: false, span: block_span, statements, end_block: EndBlockStatementId::new_invalid(), - scope_node: ScopeNode::new_invalid(), - first_unique_id_in_scope: -1, - next_unique_id_in_scope: -1, - locals: Vec::new(), - labels: Vec::new(), + scope: ScopeId::new_invalid(), next: StatementId::new_invalid(), }); + let scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::Block(block_id))); - let end_block = ctx.heap.alloc_end_block_statement(|this| EndBlockStatement{ - this, start_block: id, next: StatementId::new_invalid() + let end_block_id = ctx.heap.alloc_end_block_statement(|this| EndBlockStatement{ + this, start_block: block_id, next: StatementId::new_invalid() }); - let block_stmt = &mut ctx.heap[id]; - block_stmt.end_block = end_block; + let block_stmt = &mut ctx.heap[block_id]; + block_stmt.end_block = end_block_id; + block_stmt.scope = scope_id; - Ok(id) + Ok(block_id) } fn consume_if_statement( @@ -568,24 +465,54 @@ impl PassDefinitions { consume_token(&module.source, iter, TokenKind::OpenParen)?; let test = self.consume_expression(module, iter, ctx)?; consume_token(&module.source, iter, TokenKind::CloseParen)?; - let true_body = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + + // Consume bodies of if-statement + let true_body = IfStatementCase{ 
+ body: self.consume_statement(module, iter, ctx)?, + scope: ScopeId::new_invalid(), + }; let false_body = if has_ident(&module.source, iter, KW_STMT_ELSE) { iter.consume(); - let false_body = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + let false_body = IfStatementCase{ + body: self.consume_statement(module, iter, ctx)?, + scope: ScopeId::new_invalid(), + }; + Some(false_body) } else { None }; - Ok(ctx.heap.alloc_if_statement(|this| IfStatement{ + // Construct AST elements + let if_stmt_id = ctx.heap.alloc_if_statement(|this| IfStatement{ this, span: if_span, test, - true_body, - false_body, + true_case: true_body, + false_case: false_body, end_if: EndIfStatementId::new_invalid(), - })) + }); + let end_if_stmt_id = ctx.heap.alloc_end_if_statement(|this| EndIfStatement{ + this, + start_if: if_stmt_id, + next: StatementId::new_invalid(), + }); + let true_scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::If(if_stmt_id, true))); + let false_scope_id = if false_body.is_some() { + Some(ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::If(if_stmt_id, false)))) + } else { + None + }; + + let if_stmt = &mut ctx.heap[if_stmt_id]; + if_stmt.end_if = end_if_stmt_id; + if_stmt.true_case.scope = true_scope_id; + if let Some(false_case) = &mut if_stmt.false_case { + false_case.scope = false_scope_id.unwrap(); + } + + return Ok(if_stmt_id); } fn consume_while_statement( @@ -595,16 +522,29 @@ impl PassDefinitions { consume_token(&module.source, iter, TokenKind::OpenParen)?; let test = self.consume_expression(module, iter, ctx)?; consume_token(&module.source, iter, TokenKind::CloseParen)?; - let body = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + let body = self.consume_statement(module, iter, ctx)?; - Ok(ctx.heap.alloc_while_statement(|this| WhileStatement{ + let while_stmt_id = ctx.heap.alloc_while_statement(|this| WhileStatement{ this, span: while_span, test, + scope: ScopeId::new_invalid(), body, end_while: EndWhileStatementId::new_invalid(), in_sync: SynchronousStatementId::new_invalid(), - })) + }); + let end_while_stmt_id = ctx.heap.alloc_end_while_statement(|this| EndWhileStatement{ + this, + start_while: while_stmt_id, + next: StatementId::new_invalid(), + }); + let scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::While(while_stmt_id))); + + let while_stmt = &mut ctx.heap[while_stmt_id]; + while_stmt.scope = scope_id; + while_stmt.end_while = end_while_stmt_id; + + Ok(while_stmt_id) } fn consume_break_statement( @@ -649,25 +589,38 @@ impl PassDefinitions { &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx ) -> Result { let synchronous_span = consume_exact_ident(&module.source, iter, KW_STMT_SYNC)?; - let body = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + let body = self.consume_statement(module, iter, ctx)?; - Ok(ctx.heap.alloc_synchronous_statement(|this| SynchronousStatement{ + let sync_stmt_id = ctx.heap.alloc_synchronous_statement(|this| SynchronousStatement{ this, span: synchronous_span, + scope: ScopeId::new_invalid(), body, end_sync: EndSynchronousStatementId::new_invalid(), - })) + }); + let end_sync_stmt_id = ctx.heap.alloc_end_synchronous_statement(|this| EndSynchronousStatement{ + this, + start_sync: sync_stmt_id, + next: StatementId::new_invalid(), + }); + let scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::Synchronous(sync_stmt_id))); + + let sync_stmt = &mut ctx.heap[sync_stmt_id]; + sync_stmt.scope = scope_id; + 
sync_stmt.end_sync = end_sync_stmt_id; + + return Ok(sync_stmt_id); } fn consume_fork_statement( &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx ) -> Result { let fork_span = consume_exact_ident(&module.source, iter, KW_STMT_FORK)?; - let left_body = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + let left_body = self.consume_statement(module, iter, ctx)?; let right_body = if has_ident(&module.source, iter, KW_STMT_OR) { iter.consume(); - let right_body = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + let right_body = self.consume_statement(module, iter, ctx)?; Some(right_body) } else { None @@ -710,9 +663,11 @@ impl PassDefinitions { }, }; consume_token(&module.source, iter, TokenKind::ArrowRight)?; - let block = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + let block = self.consume_statement(module, iter, ctx)?; cases.push(SelectCase{ - guard, block, + guard, + body: block, + scope: ScopeId::new_invalid(), involved_ports: Vec::with_capacity(1) }); @@ -721,12 +676,33 @@ impl PassDefinitions { consume_token(&module.source, iter, TokenKind::CloseCurly)?; - Ok(ctx.heap.alloc_select_statement(|this| SelectStatement{ + let num_cases = cases.len(); + let select_stmt_id = ctx.heap.alloc_select_statement(|this| SelectStatement{ this, span: select_span, cases, end_select: EndSelectStatementId::new_invalid(), - })) + relative_pos_in_parent: -1, + next: StatementId::new_invalid(), + }); + + let end_select_stmt_id = ctx.heap.alloc_end_select_statement(|this| EndSelectStatement{ + this, + start_select: select_stmt_id, + next: StatementId::new_invalid(), + }); + + let select_stmt = &mut ctx.heap[select_stmt_id]; + select_stmt.end_select = end_select_stmt_id; + + for case_index in 0..num_cases { + let scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::SelectCase(select_stmt_id, case_index as u32))); + let select_stmt = &mut ctx.heap[select_stmt_id]; + let select_case = &mut select_stmt.cases[case_index]; + select_case.scope = scope_id; + } + + return Ok(select_stmt_id) } fn consume_return_statement( @@ -855,7 +831,7 @@ impl PassDefinitions { kind: VariableKind::Local, identifier: from_identifier, parser_type: from_port_type, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, unique_id_in_scope: -1, }); @@ -870,7 +846,7 @@ impl PassDefinitions { kind: VariableKind::Local, identifier: to_identifier, parser_type: to_port_type, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, unique_id_in_scope: -1, }); @@ -879,49 +855,25 @@ impl PassDefinitions { this, span: channel_span, from, to, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, next: StatementId::new_invalid(), })) } - fn consume_labeled_statement( - &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx, section: &mut ScopedSection - ) -> Result<(), ParseError> { + fn consume_labeled_statement(&mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx) -> Result { let label = consume_ident_interned(&module.source, iter, ctx)?; consume_token(&module.source, iter, TokenKind::Colon)?; - // Not pretty: consume_statement may produce more than one statement. - // The values in the section need to be in the correct order if some - // kind of outer block is consumed, so we take another section, push - // the expressions in that one, and then allocate the labeled statement. 
- let mut inner_section = self.statements.start_section(); - self.consume_statement(module, iter, ctx, &mut inner_section)?; - debug_assert!(inner_section.len() >= 1); - + let inner_stmt_id = self.consume_statement(module, iter, ctx)?; let stmt_id = ctx.heap.alloc_labeled_statement(|this| LabeledStatement { this, label, - body: inner_section[0], - relative_pos_in_block: 0, + body: inner_stmt_id, + relative_pos_in_parent: 0, in_sync: SynchronousStatementId::new_invalid(), }); - if inner_section.len() == 1 { - // Produce the labeled statement pointing to the first statement. - // This is by far the most common case. - inner_section.forget(); - section.push(stmt_id.upcast()); - } else { - // Produce the labeled statement using the first statement, and push - // the remaining ones at the end. - let inner_statements = inner_section.into_vec(); - section.push(stmt_id.upcast()); - for idx in 1..inner_statements.len() { - section.push(inner_statements[idx]) - } - } - - Ok(()) + return Ok(stmt_id); } /// Attempts to consume a memory statement (a statement along the lines of @@ -960,7 +912,7 @@ impl PassDefinitions { kind: VariableKind::Local, identifier: identifier.clone(), parser_type, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, unique_id_in_scope: -1, }); @@ -972,7 +924,7 @@ impl PassDefinitions { declaration: Some(local_id), used_as_binding_target: false, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }); let assignment_expr_id = ctx.heap.alloc_assignment_expression(|this| AssignmentExpression{ this, @@ -982,7 +934,7 @@ impl PassDefinitions { operation: AssignmentOperator::Set, right: initial_expr_id, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }); // Put both together in the memory statement @@ -1077,7 +1029,7 @@ impl PassDefinitions { Ok(ctx.heap.alloc_assignment_expression(|this| AssignmentExpression{ this, operator_span, full_span, left, operation, right, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast()) } else { Ok(expr) @@ -1105,7 +1057,7 @@ impl PassDefinitions { Ok(ctx.heap.alloc_conditional_expression(|this| ConditionalExpression{ this, operator_span, full_span, test, true_expression, false_expression, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast()) } else { Ok(result) @@ -1290,7 +1242,7 @@ impl PassDefinitions { Ok(ctx.heap.alloc_unary_expression(|this| UnaryExpression { this, operator_span, full_span, operation, expression, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast()) } else if next == Some(TokenKind::PlusPlus) { return Err(ParseError::new_error_str_at_span( @@ -1354,7 +1306,7 @@ impl PassDefinitions { slicing_span: operator_span, full_span, subject, from_index, to_index, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast(); } else if Some(TokenKind::CloseSquare) == next { let end_span = consume_token(&module.source, iter, TokenKind::CloseSquare)?; @@ -1368,7 +1320,7 @@ impl PassDefinitions { this, operator_span, full_span, subject, index: from_index, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast(); } else { return Err(ParseError::new_error_str_at_pos( @@ -1409,7 +1361,7 @@ impl PassDefinitions { this, operator_span, full_span, subject, kind: select_kind, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast(); } @@ -1445,7 +1397,7 @@ impl 
PassDefinitions { span: InputSpan::from_positions(open_paren_pos, close_paren_pos), value: Literal::Tuple(Vec::new()), parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }); literal_id.upcast() @@ -1471,7 +1423,7 @@ impl PassDefinitions { span: InputSpan::from_positions(open_paren_pos, close_paren_pos), value: Literal::Tuple(scoped_section.into_vec()), parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }); literal_id.upcast() @@ -1499,7 +1451,7 @@ impl PassDefinitions { span: InputSpan::from_positions(start_pos, end_pos), value: Literal::Array(scoped_section.into_vec()), parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast() } else if next == Some(TokenKind::Integer) { let (literal, span) = consume_integer_literal(&module.source, iter, &mut self.buffer)?; @@ -1508,7 +1460,7 @@ impl PassDefinitions { this, span, value: Literal::Integer(LiteralInteger{ unsigned_value: literal, negated: false }), parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast() } else if next == Some(TokenKind::String) { let span = consume_string_literal(&module.source, iter, &mut self.buffer)?; @@ -1518,7 +1470,7 @@ impl PassDefinitions { this, span, value: Literal::String(interned), parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast() } else if next == Some(TokenKind::Character) { let (character, span) = consume_character_literal(&module.source, iter)?; @@ -1527,7 +1479,7 @@ impl PassDefinitions { this, span, value: Literal::Character(character), parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast() } else if next == Some(TokenKind::Ident) { // May be a variable, a type instantiation or a function call. 
If we @@ -1579,7 +1531,7 @@ impl PassDefinitions { definition: target_definition_id, }), parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast() }, Definition::Enum(_) => { @@ -1597,7 +1549,7 @@ impl PassDefinitions { variant_idx: 0 }), parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast() }, Definition::Union(_) => { @@ -1622,31 +1574,14 @@ impl PassDefinitions { variant_idx: 0, }), parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast() }, - Definition::Component(_) => { - // Component instantiation - let func_span = parser_type.full_span; - let mut full_span = func_span; - let arguments = self.consume_expression_list( - module, iter, ctx, Some(&mut full_span.end) - )?; - - ctx.heap.alloc_call_expression(|this| CallExpression{ - this, func_span, full_span, - parser_type, - method: Method::UserComponent, - arguments, - definition: target_definition_id, - parent: ExpressionParent::None, - unique_id_in_definition: -1, - }).upcast() - }, - Definition::Function(function_definition) => { + Definition::Procedure(proc_def) => { // Check whether it is a builtin function - let method = if function_definition.builtin { - match function_definition.identifier.value.as_bytes() { + let procedure_id = proc_def.this; + let method = if proc_def.builtin { + match proc_def.identifier.value.as_bytes() { KW_FUNC_GET => Method::Get, KW_FUNC_PUT => Method::Put, KW_FUNC_FIRES => Method::Fires, @@ -1656,8 +1591,10 @@ impl PassDefinitions { KW_FUNC_PRINT => Method::Print, _ => unreachable!(), } - } else { + } else if proc_def.kind == ProcedureKind::Function { Method::UserFunction + } else { + Method::UserComponent }; // Function call: consume the arguments @@ -1669,9 +1606,9 @@ impl PassDefinitions { ctx.heap.alloc_call_expression(|this| CallExpression{ this, func_span, full_span, parser_type, method, arguments, - definition: target_definition_id, + procedure: procedure_id, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast() } } @@ -1700,7 +1637,7 @@ impl PassDefinitions { span: ident_span, value, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast() } else if ident_text == KW_LET { // Binding expression @@ -1718,7 +1655,7 @@ impl PassDefinitions { ctx.heap.alloc_binding_expression(|this| BindingExpression{ this, operator_span, full_span, bound_to, bound_from, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast() } else if ident_text == KW_CAST { // Casting expression @@ -1755,7 +1692,7 @@ impl PassDefinitions { cast_span: to_type.full_span, full_span, to_type, subject, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast() } else { // Not a builtin literal, but also not a known type. 
So we @@ -1790,7 +1727,7 @@ impl PassDefinitions { declaration: None, used_as_binding_target: false, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast() } } @@ -1830,7 +1767,7 @@ impl PassDefinitions { result = ctx.heap.alloc_binary_expression(|this| BinaryExpression{ this, operator_span, full_span, left, operation, right, parent: ExpressionParent::None, - unique_id_in_definition: -1, + type_index: -1, }).upcast(); } @@ -1883,7 +1820,7 @@ fn consume_parameter_list( kind: VariableKind::Parameter, parser_type, identifier, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, unique_id_in_scope: -1, }); Ok(parameter_id) diff --git a/src/protocol/parser/pass_rewriting.rs b/src/protocol/parser/pass_rewriting.rs new file mode 100644 index 0000000000000000000000000000000000000000..ccbfca20870bad54d815d7c84983fd2c025bb478 --- /dev/null +++ b/src/protocol/parser/pass_rewriting.rs @@ -0,0 +1,683 @@ +use crate::collections::*; +use crate::protocol::*; + +use super::visitor::*; + +pub(crate) struct PassRewriting { + current_scope: ScopeId, + current_procedure_id: ProcedureDefinitionId, + definition_buffer: ScopedBuffer, + statement_buffer: ScopedBuffer, + call_expr_buffer: ScopedBuffer, + expression_buffer: ScopedBuffer, + scope_buffer: ScopedBuffer, +} + +impl PassRewriting { + pub(crate) fn new() -> Self { + Self{ + current_scope: ScopeId::new_invalid(), + current_procedure_id: ProcedureDefinitionId::new_invalid(), + definition_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + statement_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + call_expr_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + expression_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + scope_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + } + } +} + +impl Visitor for PassRewriting { + fn visit_module(&mut self, ctx: &mut Ctx) -> VisitorResult { + let module = ctx.module(); + debug_assert_eq!(module.phase, ModuleCompilationPhase::Typed); + + let root_id = module.root_id; + let root = &ctx.heap[root_id]; + let definition_section = self.definition_buffer.start_section_initialized(&root.definitions); + for definition_index in 0..definition_section.len() { + let definition_id = definition_section[definition_index]; + self.visit_definition(ctx, definition_id)?; + } + + definition_section.forget(); + ctx.module_mut().phase = ModuleCompilationPhase::Rewritten; + return Ok(()) + } + + // --- Visiting procedures + + fn visit_procedure_definition(&mut self, ctx: &mut Ctx, id: ProcedureDefinitionId) -> VisitorResult { + let definition = &ctx.heap[id]; + let body_id = definition.body; + self.current_scope = definition.scope; + self.current_procedure_id = id; + return self.visit_block_stmt(ctx, body_id); + } + + // --- Visiting statements (that are not the select statement) + + fn visit_block_stmt(&mut self, ctx: &mut Ctx, id: BlockStatementId) -> VisitorResult { + let block_stmt = &ctx.heap[id]; + let stmt_section = self.statement_buffer.start_section_initialized(&block_stmt.statements); + + self.current_scope = block_stmt.scope; + for stmt_idx in 0..stmt_section.len() { + self.visit_stmt(ctx, stmt_section[stmt_idx])?; + } + + stmt_section.forget(); + return Ok(()) + } + + fn visit_labeled_stmt(&mut self, ctx: &mut Ctx, id: LabeledStatementId) -> VisitorResult { + let labeled_stmt = &ctx.heap[id]; + let body_id = labeled_stmt.body; + return self.visit_stmt(ctx, body_id); + } + + fn visit_if_stmt(&mut self, ctx: &mut Ctx, id: IfStatementId) 
-> VisitorResult { + let if_stmt = &ctx.heap[id]; + let true_case = if_stmt.true_case; + let false_case = if_stmt.false_case; + + self.current_scope = true_case.scope; + self.visit_stmt(ctx, true_case.body)?; + if let Some(false_case) = false_case { + self.current_scope = false_case.scope; + self.visit_stmt(ctx, false_case.body)?; + } + + return Ok(()) + } + + fn visit_while_stmt(&mut self, ctx: &mut Ctx, id: WhileStatementId) -> VisitorResult { + let while_stmt = &ctx.heap[id]; + let body_id = while_stmt.body; + self.current_scope = while_stmt.scope; + return self.visit_stmt(ctx, body_id); + } + + fn visit_synchronous_stmt(&mut self, ctx: &mut Ctx, id: SynchronousStatementId) -> VisitorResult { + let sync_stmt = &ctx.heap[id]; + let body_id = sync_stmt.body; + self.current_scope = sync_stmt.scope; + return self.visit_stmt(ctx, body_id); + } + + // --- Visiting the select statement + + fn visit_select_stmt(&mut self, ctx: &mut Ctx, id: SelectStatementId) -> VisitorResult { + // Utility for the last stage of rewriting process. Note that caller + // still needs to point the end of the if-statement to the end of the + // replacement statement of the select statement. + fn transform_select_case_code( + ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, + select_id: SelectStatementId, case_index: usize, + select_var_id: VariableId, select_var_type_id: TypeIdReference + ) -> (IfStatementId, EndIfStatementId, ScopeId) { + // Retrieve statement IDs associated with case + let case = &ctx.heap[select_id].cases[case_index]; + let case_guard_id = case.guard; + let case_body_id = case.body; + let case_scope_id = case.scope; + + // Create the if-statement for the result of the select statement + let compare_expr_id = create_ast_equality_comparison_expr(ctx, containing_procedure_id, select_var_id, select_var_type_id, case_index as u64); + let true_case = IfStatementCase{ + body: case_guard_id, // which is linked up to the body + scope: case_scope_id, + }; + let (if_stmt_id, end_if_stmt_id) = create_ast_if_stmt(ctx, compare_expr_id.upcast(), true_case, None); + + // Link up body statement to end-if + set_ast_statement_next(ctx, case_body_id, end_if_stmt_id.upcast()); + + return (if_stmt_id, end_if_stmt_id, case_scope_id); + } + + // Precreate the block that will end up containing all of the + // transformed statements. Also precreate the scope associated with it + let (outer_block_id, outer_end_block_id, outer_scope_id) = + create_ast_block_stmt(ctx, Vec::new()); + + // The "select" and the "end select" statement will act like trampolines + // that jump to the replacement block. So set the child/parent + // relationship already. + // --- for the statements + let select_stmt = &mut ctx.heap[id]; + select_stmt.next = outer_block_id.upcast(); + let end_select_stmt_id = select_stmt.end_select; + let select_stmt_relative_pos = select_stmt.relative_pos_in_parent; + + let outer_end_block_stmt = &mut ctx.heap[outer_end_block_id]; + outer_end_block_stmt.next = end_select_stmt_id.upcast(); + + // --- for the scopes + link_new_child_to_existing_parent_scope(ctx, &mut self.scope_buffer, self.current_scope, outer_scope_id, select_stmt_relative_pos); + + // Create statements that will create temporary variables for all of the + // ports passed to the "get" calls in the select case guards. 
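As a rough, self-contained illustration of this hoisting step (toy `Expr` type and made-up names such as `hoist_first_argument`, not the compiler's AST heap): the port expression inside a `get(...)` call is moved into its own declaration, and the call is left reading a temporary variable.

```rust
// Illustrative sketch only: the first argument of a `get(...)` call is moved
// into a separate declaration and the call is rewritten to read a temporary.
#[derive(Debug)]
enum Expr {
    Var(String),
    Call { name: String, args: Vec<Expr> },
}

/// Returns `(temp_name, original_argument)` for a later declaration statement
/// and leaves the call reading the temporary variable instead.
fn hoist_first_argument(call: &mut Expr, temp_name: &str) -> Option<(String, Expr)> {
    if let Expr::Call { args, .. } = call {
        let original = std::mem::replace(&mut args[0], Expr::Var(temp_name.to_string()));
        return Some((temp_name.to_string(), original));
    }
    None
}

fn main() {
    let mut get_call = Expr::Call {
        name: "get".to_string(),
        args: vec![Expr::Var("some_port_expression".to_string())],
    };
    let declaration = hoist_first_argument(&mut get_call, "tmp_port_0");
    println!("declaration: {:?}", declaration); // becomes `tmp_port_0 = <port expr>`
    println!("rewritten call: {:?}", get_call); // now reads the temporary
}
```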
+ let select_stmt = &ctx.heap[id]; + let total_num_cases = select_stmt.cases.len(); + let mut total_num_ports = 0; + let end_select_stmt_id = select_stmt.end_select; + let _end_select = &ctx.heap[end_select_stmt_id]; + + // Put heap IDs into temporary buffers to handle borrowing rules + let mut call_id_section = self.call_expr_buffer.start_section(); + let mut expr_id_section = self.expression_buffer.start_section(); + + for case in select_stmt.cases.iter() { + total_num_ports += case.involved_ports.len(); + for (call_id, expr_id) in case.involved_ports.iter().copied() { + call_id_section.push(call_id); + expr_id_section.push(expr_id); + } + } + + // Transform all of the call expressions by takings its argument (the + // port from which we `get`) and turning it into a temporary variable. + let mut transformed_stmts = Vec::with_capacity(total_num_ports); // TODO: Recompute this preallocated length, put assert at the end + let mut locals = Vec::with_capacity(total_num_ports); + + for port_var_idx in 0..call_id_section.len() { + let get_call_expr_id = call_id_section[port_var_idx]; + let port_expr_id = expr_id_section[port_var_idx]; + let port_type_index = ctx.heap[port_expr_id].type_index(); + let port_type_ref = TypeIdReference::IndirectSameAsExpr(port_type_index); + + // Move the port expression such that it gets assigned to a temporary variable + let variable_id = create_ast_variable(ctx, outer_scope_id); + let variable_decl_stmt_id = create_ast_variable_declaration_stmt(ctx, self.current_procedure_id, variable_id, port_type_ref, port_expr_id); + + // Replace the original port expression in the call with a reference + // to the replacement variable + let variable_expr_id = create_ast_variable_expr(ctx, self.current_procedure_id, variable_id, port_type_ref); + let call_expr = &mut ctx.heap[get_call_expr_id]; + call_expr.arguments[0] = variable_expr_id.upcast(); + + transformed_stmts.push(variable_decl_stmt_id.upcast().upcast()); + locals.push((variable_id, port_type_ref)); + } + + // Insert runtime calls that facilitate the semantics of the select + // block. + + // Create the call that indicates the start of the select block + { + let num_cases_expression_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, total_num_cases as u64, ctx.arch.uint32_type_id); + let num_ports_expression_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, total_num_ports as u64, ctx.arch.uint32_type_id); + let arguments = vec![ + num_cases_expression_id.upcast(), + num_ports_expression_id.upcast() + ]; + + let call_expression_id = create_ast_call_expr(ctx, self.current_procedure_id, Method::SelectStart, &mut self.expression_buffer, arguments); + let call_statement_id = create_ast_expression_stmt(ctx, call_expression_id.upcast()); + + transformed_stmts.push(call_statement_id.upcast()); + } + + // Create calls for each select case that will register the ports that + // we are waiting on at the runtime. 
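Before the registration loop below, it may help to see the overall order of runtime calls the prologue is assembled to produce: one `SelectStart(num_cases, num_ports)`, one `SelectRegisterCasePort` per (case, port) pair, and a final `SelectWait` whose result picks the case to run. A small sketch with assumed toy types (`RuntimeCall` and `emit_select_prologue` are illustrative, not part of this patch):

```rust
// Sketch of the call order only; the real pass emits AST call expressions.
#[derive(Debug)]
enum RuntimeCall {
    SelectStart { num_cases: u64, num_ports: u64 },
    SelectRegisterCasePort { case_index: u64, port_index: u64 },
    SelectWait, // produces the index of the case that fired
}

fn emit_select_prologue(ports_per_case: &[u64]) -> Vec<RuntimeCall> {
    let num_cases = ports_per_case.len() as u64;
    let num_ports: u64 = ports_per_case.iter().sum();

    let mut calls = vec![RuntimeCall::SelectStart { num_cases, num_ports }];
    for (case_index, &case_num_ports) in ports_per_case.iter().enumerate() {
        for port_index in 0..case_num_ports {
            calls.push(RuntimeCall::SelectRegisterCasePort {
                case_index: case_index as u64,
                port_index,
            });
        }
    }
    calls.push(RuntimeCall::SelectWait);
    calls
}

fn main() {
    // Hypothetical select with two cases, waiting on one and two ports.
    for call in emit_select_prologue(&[1, 2]) {
        println!("{:?}", call);
    }
}
```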
+ { + let mut total_port_index = 0; + for case_index in 0..total_num_cases { + let case = &ctx.heap[id].cases[case_index]; + let case_num_ports = case.involved_ports.len(); + + for case_port_index in 0..case_num_ports { + // Arguments to runtime call + let (port_variable_id, port_variable_type) = locals[total_port_index]; // so far this variable contains the temporary variables for the port expressions + let case_index_expr_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, case_index as u64, ctx.arch.uint32_type_id); + let port_index_expr_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, case_port_index as u64, ctx.arch.uint32_type_id); + let port_variable_expr_id = create_ast_variable_expr(ctx, self.current_procedure_id, port_variable_id, port_variable_type); + let runtime_call_arguments = vec![ + case_index_expr_id.upcast(), + port_index_expr_id.upcast(), + port_variable_expr_id.upcast() + ]; + + // Create runtime call, then store it + let runtime_call_expr_id = create_ast_call_expr(ctx, self.current_procedure_id, Method::SelectRegisterCasePort, &mut self.expression_buffer, runtime_call_arguments); + let runtime_call_stmt_id = create_ast_expression_stmt(ctx, runtime_call_expr_id.upcast()); + + transformed_stmts.push(runtime_call_stmt_id.upcast()); + + total_port_index += 1; + } + } + } + + // Create the variable that will hold the result of a completed select + // block. Then create the runtime call that will produce this result + let select_variable_id = create_ast_variable(ctx, outer_scope_id); + let select_variable_type = TypeIdReference::DirectTypeId(ctx.arch.uint32_type_id); + locals.push((select_variable_id, select_variable_type)); + + { + let runtime_call_expr_id = create_ast_call_expr(ctx, self.current_procedure_id, Method::SelectWait, &mut self.expression_buffer, Vec::new()); + let variable_stmt_id = create_ast_variable_declaration_stmt(ctx, self.current_procedure_id, select_variable_id, select_variable_type, runtime_call_expr_id.upcast()); + transformed_stmts.push(variable_stmt_id.upcast().upcast()); + } + + call_id_section.forget(); + expr_id_section.forget(); + + // Now we transform each of the select block case's guard and code into + // a chained if-else statement. 
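A minimal sketch of the chained shape each case ends up in, assuming a toy `LoweredIf` node type. The real pass builds the chain front-to-back by filling in each if-statement's false case; the sketch below builds the same shape back-to-front for brevity:

```rust
// Sketch: `if r == 0 { .. } else { if r == 1 { .. } else { .. } }`, where `r`
// is the value produced by the select-wait call.
#[derive(Debug)]
struct LoweredIf {
    case_index: u64,                     // condition: select_result == case_index
    then_branch: String,                 // stands in for "case guard, then case body"
    else_branch: Option<Box<LoweredIf>>, // next case in the chain, if any
}

fn build_chain(case_bodies: &[&str]) -> Option<Box<LoweredIf>> {
    let mut chain = None;
    for (index, body) in case_bodies.iter().enumerate().rev() {
        chain = Some(Box::new(LoweredIf {
            case_index: index as u64,
            then_branch: body.to_string(),
            else_branch: chain,
        }));
    }
    chain
}

fn main() {
    // Three hypothetical select cases become a nested if-else chain.
    println!("{:#?}", build_chain(&["case 0", "case 1", "case 2"]));
}
```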
+ let mut relative_pos = transformed_stmts.len() as i32; + if total_num_cases > 0 { + let (if_stmt_id, end_if_stmt_id, scope_id) = transform_select_case_code(ctx, self.current_procedure_id, id, 0, select_variable_id, select_variable_type); + link_existing_child_to_new_parent_scope(ctx, &mut self.scope_buffer, outer_scope_id, scope_id, relative_pos); + let first_end_if_stmt = &mut ctx.heap[end_if_stmt_id]; + first_end_if_stmt.next = outer_end_block_id.upcast(); + + let mut last_if_stmt_id = if_stmt_id; + let mut last_end_if_stmt_id = end_if_stmt_id; + let mut last_parent_scope_id = outer_scope_id; + let mut last_relative_pos = transformed_stmts.len() as i32 + 1; + transformed_stmts.push(last_if_stmt_id.upcast()); + + for case_index in 1..total_num_cases { + let (if_stmt_id, end_if_stmt_id, scope_id) = transform_select_case_code(ctx, self.current_procedure_id, id, case_index, select_variable_id, select_variable_type); + let false_case_scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::If(last_if_stmt_id, false))); + link_existing_child_to_new_parent_scope(ctx, &mut self.scope_buffer, false_case_scope_id, scope_id, 0); + link_new_child_to_existing_parent_scope(ctx, &mut self.scope_buffer, last_parent_scope_id, false_case_scope_id, last_relative_pos); + set_ast_if_statement_false_body(ctx, last_if_stmt_id, last_end_if_stmt_id, IfStatementCase{ body: if_stmt_id.upcast(), scope: false_case_scope_id }); + + let end_if_stmt = &mut ctx.heap[end_if_stmt_id]; + end_if_stmt.next = last_end_if_stmt_id.upcast(); + + last_if_stmt_id = if_stmt_id; + last_end_if_stmt_id = end_if_stmt_id; + last_parent_scope_id = false_case_scope_id; + last_relative_pos = 0; + } + } + + // Final steps: set the statements of the replacement block statement, + // link all of those statements together, and update the scopes. 
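A toy `next`-pointer model of the control flow that the final linking produces, with made-up statement labels; the original select and end-select statements stay in place and merely jump into, and back out of, the generated block:

```rust
// Sketch only: statements as nodes with a `next` index, mirroring how the
// generated block is threaded between the original select and end-select.
struct Stmt {
    name: &'static str,
    next: Option<usize>,
}

fn main() {
    let mut stmts = vec![
        Stmt { name: "select", next: None },                           // 0
        Stmt { name: "outer_block", next: None },                      // 1
        Stmt { name: "temp port decls + runtime calls", next: None },  // 2
        Stmt { name: "if-else chain over cases", next: None },         // 3
        Stmt { name: "outer_end_block", next: None },                  // 4
        Stmt { name: "end_select", next: None },                       // 5
    ];

    // Link everything up in execution order.
    let order = [0usize, 1, 2, 3, 4, 5];
    for pair in order.windows(2) {
        stmts[pair[0]].next = Some(pair[1]);
    }

    // Walk the chain to show the resulting control flow.
    let mut at = Some(0);
    while let Some(index) = at {
        println!("-> {}", stmts[index].name);
        at = stmts[index].next;
    }
}
```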
+ let first_stmt_id = transformed_stmts[0]; + let mut last_stmt_id = transformed_stmts[0]; + for stmt_id in transformed_stmts.iter().skip(1).copied() { + set_ast_statement_next(ctx, last_stmt_id, stmt_id); + last_stmt_id = stmt_id; + } + + if total_num_cases == 0 { + // If we don't have any cases, then we didn't connect the statements + // up to the end of the outer block, so do that here + set_ast_statement_next(ctx, last_stmt_id, outer_end_block_id.upcast()); + } + + let outer_block_stmt = &mut ctx.heap[outer_block_id]; + outer_block_stmt.next = first_stmt_id; + outer_block_stmt.statements = transformed_stmts; + + return Ok(()) + } +} + +// ----------------------------------------------------------------------------- +// Utilities to create compiler-generated AST nodes +// ----------------------------------------------------------------------------- + +#[derive(Clone, Copy)] +enum TypeIdReference { + DirectTypeId(TypeId), + IndirectSameAsExpr(i32), // by type index +} + +fn create_ast_variable(ctx: &mut Ctx, scope_id: ScopeId) -> VariableId { + let variable_id = ctx.heap.alloc_variable(|this| Variable{ + this, + kind: VariableKind::Local, + parser_type: ParserType{ + elements: Vec::new(), + full_span: InputSpan::new(), + }, + identifier: Identifier::new_empty(InputSpan::new()), + relative_pos_in_parent: -1, + unique_id_in_scope: -1, + }); + let scope = &mut ctx.heap[scope_id]; + scope.variables.push(variable_id); + + return variable_id; +} + +fn create_ast_variable_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, variable_id: VariableId, variable_type_id: TypeIdReference) -> VariableExpressionId { + let variable_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, variable_type_id); + return ctx.heap.alloc_variable_expression(|this| VariableExpression{ + this, + identifier: Identifier::new_empty(InputSpan::new()), + declaration: Some(variable_id), + used_as_binding_target: false, + parent: ExpressionParent::None, + type_index: variable_type_index, + }); +} + +fn create_ast_call_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, method: Method, buffer: &mut ScopedBuffer, arguments: Vec) -> CallExpressionId { + let call_type_id = match method { + Method::SelectStart => ctx.arch.void_type_id, + Method::SelectRegisterCasePort => ctx.arch.void_type_id, + Method::SelectWait => ctx.arch.uint32_type_id, // TODO: Not pretty, this. Pretty error prone + _ => unreachable!(), // if this goes of, add the appropriate method here. 
+ }; + + let expression_ids = buffer.start_section_initialized(&arguments); + let call_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(call_type_id)); + let call_expression_id = ctx.heap.alloc_call_expression(|this| CallExpression{ + func_span: InputSpan::new(), + this, + full_span: InputSpan::new(), + parser_type: ParserType{ + elements: Vec::new(), + full_span: InputSpan::new(), + }, + method, + arguments, + procedure: ProcedureDefinitionId::new_invalid(), + parent: ExpressionParent::None, + type_index: call_type_index, + }); + + for argument_index in 0..expression_ids.len() { + let argument_id = expression_ids[argument_index]; + let argument_expr = &mut ctx.heap[argument_id]; + *argument_expr.parent_mut() = ExpressionParent::Expression(call_expression_id.upcast(), argument_index as u32); + } + + return call_expression_id; +} + +fn create_ast_literal_integer_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, unsigned_value: u64, type_id: TypeId) -> LiteralExpressionId { + let literal_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(type_id)); + return ctx.heap.alloc_literal_expression(|this| LiteralExpression{ + this, + span: InputSpan::new(), + value: Literal::Integer(LiteralInteger{ + unsigned_value, + negated: false, + }), + parent: ExpressionParent::None, + type_index: literal_type_index, + }); +} + +fn create_ast_equality_comparison_expr( + ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, + variable_id: VariableId, variable_type: TypeIdReference, value: u64 +) -> BinaryExpressionId { + let var_expr_id = create_ast_variable_expr(ctx, containing_procedure_id, variable_id, variable_type); + let int_expr_id = create_ast_literal_integer_expr(ctx, containing_procedure_id, value, ctx.arch.uint32_type_id); + let cmp_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(ctx.arch.bool_type_id)); + let cmp_expr_id = ctx.heap.alloc_binary_expression(|this| BinaryExpression{ + this, + operator_span: InputSpan::new(), + full_span: InputSpan::new(), + left: var_expr_id.upcast(), + operation: BinaryOperator::Equality, + right: int_expr_id.upcast(), + parent: ExpressionParent::None, + type_index: cmp_type_index, + }); + + let var_expr = &mut ctx.heap[var_expr_id]; + var_expr.parent = ExpressionParent::Expression(cmp_expr_id.upcast(), 0); + let int_expr = &mut ctx.heap[int_expr_id]; + int_expr.parent = ExpressionParent::Expression(cmp_expr_id.upcast(), 1); + + return cmp_expr_id; +} + +fn create_ast_expression_stmt(ctx: &mut Ctx, expression_id: ExpressionId) -> ExpressionStatementId { + let statement_id = ctx.heap.alloc_expression_statement(|this| ExpressionStatement{ + this, + span: InputSpan::new(), + expression: expression_id, + next: StatementId::new_invalid(), + }); + + let expression = &mut ctx.heap[expression_id]; + *expression.parent_mut() = ExpressionParent::ExpressionStmt(statement_id); + + return statement_id; +} + +fn create_ast_variable_declaration_stmt( + ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, + variable_id: VariableId, variable_type: TypeIdReference, initial_value_expr_id: ExpressionId +) -> MemoryStatementId { + // Create the assignment expression, assigning the initial value to the variable + let variable_expr_id = create_ast_variable_expr(ctx, containing_procedure_id, variable_id, variable_type); + let void_type_index = add_new_procedure_expression_type(ctx, 
containing_procedure_id, TypeIdReference::DirectTypeId(ctx.arch.void_type_id)); + let assignment_expr_id = ctx.heap.alloc_assignment_expression(|this| AssignmentExpression{ + this, + operator_span: InputSpan::new(), + full_span: InputSpan::new(), + left: variable_expr_id.upcast(), + operation: AssignmentOperator::Set, + right: initial_value_expr_id, + parent: ExpressionParent::None, + type_index: void_type_index, + }); + + // Create the memory statement + let memory_stmt_id = ctx.heap.alloc_memory_statement(|this| MemoryStatement{ + this, + span: InputSpan::new(), + variable: variable_id, + initial_expr: assignment_expr_id, + next: StatementId::new_invalid(), + }); + + // Set all parents which we can access + let variable_expr = &mut ctx.heap[variable_expr_id]; + variable_expr.parent = ExpressionParent::Expression(assignment_expr_id.upcast(), 0); + let value_expr = &mut ctx.heap[initial_value_expr_id]; + *value_expr.parent_mut() = ExpressionParent::Expression(assignment_expr_id.upcast(), 1); + let assignment_expr = &mut ctx.heap[assignment_expr_id]; + assignment_expr.parent = ExpressionParent::Memory(memory_stmt_id); + + return memory_stmt_id; +} + +fn create_ast_block_stmt(ctx: &mut Ctx, statements: Vec) -> (BlockStatementId, EndBlockStatementId, ScopeId) { + let block_stmt_id = ctx.heap.alloc_block_statement(|this| BlockStatement{ + this, + span: InputSpan::new(), + statements, + end_block: EndBlockStatementId::new_invalid(), + scope: ScopeId::new_invalid(), + next: StatementId::new_invalid(), + }); + let end_block_stmt_id = ctx.heap.alloc_end_block_statement(|this| EndBlockStatement{ + this, + start_block: block_stmt_id, + next: StatementId::new_invalid(), + }); + let scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::Block(block_stmt_id))); + + let block_stmt = &mut ctx.heap[block_stmt_id]; + block_stmt.end_block = end_block_stmt_id; + block_stmt.scope = scope_id; + + return (block_stmt_id, end_block_stmt_id, scope_id); +} + +fn create_ast_if_stmt(ctx: &mut Ctx, condition_expression_id: ExpressionId, true_case: IfStatementCase, false_case: Option) -> (IfStatementId, EndIfStatementId) { + // Create if statement and the end-if statement + let if_stmt_id = ctx.heap.alloc_if_statement(|this| IfStatement{ + this, + span: InputSpan::new(), + test: condition_expression_id, + true_case, + false_case, + end_if: EndIfStatementId::new_invalid() + }); + + let end_if_stmt_id = ctx.heap.alloc_end_if_statement(|this| EndIfStatement{ + this, + start_if: if_stmt_id, + next: StatementId::new_invalid(), + }); + + // Link the statements up as much as we can + let if_stmt = &mut ctx.heap[if_stmt_id]; + if_stmt.end_if = end_if_stmt_id; + + let condition_expr = &mut ctx.heap[condition_expression_id]; + *condition_expr.parent_mut() = ExpressionParent::If(if_stmt_id); + + + + return (if_stmt_id, end_if_stmt_id); +} + +/// Sets the false body for a given +fn set_ast_if_statement_false_body(ctx: &mut Ctx, if_statement_id: IfStatementId, end_if_statement_id: EndIfStatementId, false_case: IfStatementCase) { + // Point if-statement to "false body" + let if_stmt = &mut ctx.heap[if_statement_id]; + debug_assert!(if_stmt.false_case.is_none()); // simplifies logic, not necessary + if_stmt.false_case = Some(false_case); + + // Point end of false body to the end of the if statement + set_ast_statement_next(ctx, false_case.body, end_if_statement_id.upcast()); +} + +/// Sets the specified AST statement's control flow such that it will be +/// followed by the target statement. 
This may seem obvious, but may imply that +/// a statement associated with, but different from, the source statement is +/// modified. +fn set_ast_statement_next(ctx: &mut Ctx, source_stmt_id: StatementId, target_stmt_id: StatementId) { + let source_stmt = &mut ctx.heap[source_stmt_id]; + match source_stmt { + Statement::Block(stmt) => { + let end_id = stmt.end_block; + ctx.heap[end_id].next = target_stmt_id + }, + Statement::EndBlock(stmt) => stmt.next = target_stmt_id, + Statement::Local(stmt) => { + match stmt { + LocalStatement::Memory(stmt) => stmt.next = target_stmt_id, + LocalStatement::Channel(stmt) => stmt.next = target_stmt_id, + } + }, + Statement::Labeled(stmt) => { + let body_id = stmt.body; + set_ast_statement_next(ctx, body_id, target_stmt_id); + }, + Statement::If(stmt) => { + let end_id = stmt.end_if; + ctx.heap[end_id].next = target_stmt_id; + }, + Statement::EndIf(stmt) => stmt.next = target_stmt_id, + Statement::While(stmt) => { + let end_id = stmt.end_while; + ctx.heap[end_id].next = target_stmt_id; + }, + Statement::EndWhile(stmt) => stmt.next = target_stmt_id, + + Statement::Break(_stmt) => {}, + Statement::Continue(_stmt) => {}, + Statement::Synchronous(stmt) => { + let end_id = stmt.end_sync; + ctx.heap[end_id].next = target_stmt_id; + }, + Statement::EndSynchronous(stmt) => { + stmt.next = target_stmt_id; + }, + Statement::Fork(_) | Statement::EndFork(_) => { + todo!("remove fork from language"); + }, + Statement::Select(stmt) => { + let end_id = stmt.end_select; + ctx.heap[end_id].next = target_stmt_id; + }, + Statement::EndSelect(stmt) => stmt.next = target_stmt_id, + Statement::Return(_stmt) => {}, + Statement::Goto(_stmt) => {}, + Statement::New(stmt) => stmt.next = target_stmt_id, + Statement::Expression(stmt) => stmt.next = target_stmt_id, + } +} + +/// Links a new scope to an existing scope as its child. +fn link_new_child_to_existing_parent_scope(ctx: &mut Ctx, scope_buffer: &mut ScopedBuffer, parent_scope_id: ScopeId, child_scope_id: ScopeId, relative_pos_hint: i32) { + let child_scope = &mut ctx.heap[child_scope_id]; + debug_assert!(child_scope.parent.is_none()); + + child_scope.parent = Some(parent_scope_id); + child_scope.relative_pos_in_parent = relative_pos_hint; + + add_child_scope_to_parent(ctx, scope_buffer, parent_scope_id, child_scope_id, relative_pos_hint); +} + +/// Relinks an existing scope to a new scope as its child. Will also break the +/// link of the child scope's old parent. +fn link_existing_child_to_new_parent_scope(ctx: &mut Ctx, scope_buffer: &mut ScopedBuffer, new_parent_scope_id: ScopeId, child_scope_id: ScopeId, new_relative_pos_in_parent: i32) { + let child_scope = &mut ctx.heap[child_scope_id]; + let old_parent_scope_id = child_scope.parent.unwrap(); + child_scope.parent = Some(new_parent_scope_id); + child_scope.relative_pos_in_parent = new_relative_pos_in_parent; + + // Remove from old parent + let old_parent = &mut ctx.heap[old_parent_scope_id]; + let scope_index = old_parent.nested.iter() + .position(|v| *v == child_scope_id) + .unwrap(); + old_parent.nested.remove(scope_index); + + // Add to new parent + add_child_scope_to_parent(ctx, scope_buffer, new_parent_scope_id, child_scope_id, new_relative_pos_in_parent); +} + +/// Will add a child scope to a parent scope using the relative position hint. 
+fn add_child_scope_to_parent(ctx: &mut Ctx, scope_buffer: &mut ScopedBuffer, parent_scope_id: ScopeId, child_scope_id: ScopeId, relative_pos_hint: i32) { + let parent_scope = &ctx.heap[parent_scope_id]; + + let existing_scope_ids = scope_buffer.start_section_initialized(&parent_scope.nested); + let mut insert_pos = existing_scope_ids.len(); + for index in 0..existing_scope_ids.len() { + let existing_scope_id = existing_scope_ids[index]; + let existing_scope = &ctx.heap[existing_scope_id]; + if relative_pos_hint <= existing_scope.relative_pos_in_parent { + insert_pos = index; + break; + } + } + existing_scope_ids.forget(); + + let parent_scope = &mut ctx.heap[parent_scope_id]; + parent_scope.nested.insert(insert_pos, child_scope_id); +} + +fn add_new_procedure_expression_type(ctx: &mut Ctx, procedure_id: ProcedureDefinitionId, type_id: TypeIdReference) -> i32 { + let procedure = &mut ctx.heap[procedure_id]; + let type_index = procedure.monomorphs[0].expr_info.len(); + + match type_id { + TypeIdReference::DirectTypeId(type_id) => { + for monomorph in procedure.monomorphs.iter_mut() { + debug_assert_eq!(monomorph.expr_info.len(), type_index); + monomorph.expr_info.push(ExpressionInfo{ + type_id, + variant: ExpressionInfoVariant::Generic + }); + } + }, + TypeIdReference::IndirectSameAsExpr(source_type_index) => { + for monomorph in procedure.monomorphs.iter_mut() { + debug_assert_eq!(monomorph.expr_info.len(), type_index); + let copied_expr_info = monomorph.expr_info[source_type_index as usize]; + monomorph.expr_info.push(copied_expr_info) + } + } + } + + return type_index as i32; +} \ No newline at end of file diff --git a/src/protocol/parser/pass_stack_size.rs b/src/protocol/parser/pass_stack_size.rs new file mode 100644 index 0000000000000000000000000000000000000000..55acf4961cab58a23cf0d9be3fdf185e2be50951 --- /dev/null +++ b/src/protocol/parser/pass_stack_size.rs @@ -0,0 +1,112 @@ +use crate::collections::*; +use crate::protocol::*; + +use super::visitor::*; + +// Will get a rename. Will probably become bytecode emitter or something. For +// now it just scans the scopes and assigns a unique number for each variable +// such that, at any point in the program's execution, all accessible in-scope +// variables will have a unique position "on the stack". +pub(crate) struct PassStackSize { + definition_buffer: ScopedBuffer, + variable_buffer: ScopedBuffer, + scope_buffer: ScopedBuffer, +} + +impl PassStackSize { + pub(crate) fn new() -> Self { + return Self{ + definition_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + variable_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + scope_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + } + } +} + +impl Visitor for PassStackSize { + // Top level visitors + + fn visit_module(&mut self, ctx: &mut Ctx) -> VisitorResult { + let module = ctx.module(); + debug_assert_eq!(module.phase, ModuleCompilationPhase::Rewritten); + + let root_id = module.root_id; + let root = &ctx.heap[root_id]; + let definition_section = self.definition_buffer.start_section_initialized(&root.definitions); + for definition_index in 0..definition_section.len() { + let definition_id = definition_section[definition_index]; + self.visit_definition(ctx, definition_id)? 
+ } + + definition_section.forget(); + // ctx.module_mut().phase = ModuleCompilationPhase::StackSizeStuffAndStuff; + return Ok(()) + } + + fn visit_procedure_definition(&mut self, ctx: &mut Ctx, id: ProcedureDefinitionId) -> VisitorResult { + let definition = &ctx.heap[id]; + let scope_id = definition.scope; + + self.visit_scope_and_assign_local_ids(ctx, scope_id, 0); + return Ok(()); + } +} + +impl PassStackSize { + fn visit_scope_and_assign_local_ids(&mut self, ctx: &mut Ctx, scope_id: ScopeId, mut variable_counter: i32) { + let scope = &mut ctx.heap[scope_id]; + scope.first_unique_id_in_scope = variable_counter; + + let variable_section = self.variable_buffer.start_section_initialized(&scope.variables); + let child_scope_section = self.scope_buffer.start_section_initialized(&scope.nested); + + let mut variable_index = 0; + let mut child_scope_index = 0; + + loop { + // Determine relative positions of variable and scope to determine + // which one occurs first within the current scope. + let variable_relative_pos; + if variable_index < variable_section.len() { + let variable_id = variable_section[variable_index]; + let variable = &ctx.heap[variable_id]; + variable_relative_pos = variable.relative_pos_in_parent; + } else { + variable_relative_pos = i32::MAX; + } + + let child_scope_relative_pos; + if child_scope_index < child_scope_section.len() { + let child_scope_id = child_scope_section[child_scope_index]; + let child_scope = &ctx.heap[child_scope_id]; + child_scope_relative_pos = child_scope.relative_pos_in_parent; + } else { + child_scope_relative_pos = i32::MAX; + } + + if variable_relative_pos == i32::MAX && child_scope_relative_pos == i32::MAX { + // Done, no more elements in the scope to consider + break; + } + + // Label the variable/scope, whichever comes first. 
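A self-contained sketch of the numbering scheme this loop implements, using a toy scope tree instead of the compiler's heap. Passing the counter by value into nested scopes is what lets variables in sibling scopes, and variables declared after a closed block, reuse the same slots (assumed toy names `Scope`, `Entry`, `assign_ids`):

```rust
// Sketch: assign "stack slot" numbers so that all variables visible at any
// single point get distinct ids, while non-overlapping scopes may share ids.
struct Scope {
    entries: Vec<Entry>, // variables and nested scopes, in source order
}

enum Entry {
    Variable(&'static str),
    Nested(Scope),
}

fn assign_ids(scope: &Scope, mut counter: i32, out: &mut Vec<(&'static str, i32)>) {
    for entry in &scope.entries {
        match entry {
            Entry::Variable(name) => {
                out.push((*name, counter));
                counter += 1; // later variables in this scope shift upwards
            }
            // The counter is passed by value, so a sibling scope (or a later
            // variable in the parent) may reuse the same ids.
            Entry::Nested(child) => assign_ids(child, counter, out),
        }
    }
}

fn main() {
    let tree = Scope {
        entries: vec![
            Entry::Variable("a"),                                              // id 0
            Entry::Nested(Scope { entries: vec![Entry::Variable("b")] }),      // id 1
            Entry::Nested(Scope { entries: vec![Entry::Variable("c")] }),      // id 1 again
            Entry::Variable("d"),                                              // id 1 in the parent
        ],
    };
    let mut assigned = Vec::new();
    assign_ids(&tree, 0, &mut assigned);
    println!("{:?}", assigned);
}
```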
+ if variable_relative_pos <= child_scope_relative_pos { + debug_assert_ne!(variable_relative_pos, child_scope_relative_pos, "checking if this ever happens"); + let variable = &mut ctx.heap[variable_section[variable_index]]; + variable.unique_id_in_scope = variable_counter; + variable_counter += 1; + variable_index += 1; + } else { + let child_scope_id = child_scope_section[child_scope_index]; + self.visit_scope_and_assign_local_ids(ctx, child_scope_id, variable_counter); + child_scope_index += 1; + } + } + + variable_section.forget(); + child_scope_section.forget(); + + let scope = &mut ctx.heap[scope_id]; + scope.next_unique_id_in_scope = variable_counter; + } +} \ No newline at end of file diff --git a/src/protocol/parser/pass_symbols.rs b/src/protocol/parser/pass_symbols.rs index 9f48caf8d41d83e2506044c5e009aeb63794dd96..995dfa15100c3f0d1e3f44779257ea150cbeb60e 100644 --- a/src/protocol/parser/pass_symbols.rs +++ b/src/protocol/parser/pass_symbols.rs @@ -230,23 +230,23 @@ impl PassSymbols { ast_definition_id = union_def_id.upcast() }, KW_FUNCTION => { - let func_def_id = ctx.heap.alloc_function_definition(|this| { - FunctionDefinition::new_empty(this, module.root_id, definition_span, identifier, poly_vars) + let proc_def_id = ctx.heap.alloc_procedure_definition(|this| { + ProcedureDefinition::new_empty(this, module.root_id, definition_span, ProcedureKind::Function, identifier, poly_vars) }); definition_class = DefinitionClass::Function; - ast_definition_id = func_def_id.upcast(); + ast_definition_id = proc_def_id.upcast(); }, KW_PRIMITIVE | KW_COMPOSITE => { - let component_variant = if kw_text == KW_PRIMITIVE { - ComponentVariant::Primitive + let procedure_kind = if kw_text == KW_PRIMITIVE { + ProcedureKind::Primitive } else { - ComponentVariant::Composite + ProcedureKind::Composite }; - let comp_def_id = ctx.heap.alloc_component_definition(|this| { - ComponentDefinition::new_empty(this, module.root_id, definition_span, component_variant, identifier, poly_vars) + let proc_def_id = ctx.heap.alloc_procedure_definition(|this| { + ProcedureDefinition::new_empty(this, module.root_id, definition_span, procedure_kind, identifier, poly_vars) }); definition_class = DefinitionClass::Component; - ast_definition_id = comp_def_id.upcast(); + ast_definition_id = proc_def_id.upcast(); }, _ => unreachable!("encountered keyword '{}' in definition range", String::from_utf8_lossy(kw_text)), } diff --git a/src/protocol/parser/pass_typing.rs b/src/protocol/parser/pass_typing.rs index e31e13c896efe7558d80c95de258c845868faeb9..e0a8f8cd4e6f8ef95c3c3668a7c97f4978515729 100644 --- a/src/protocol/parser/pass_typing.rs +++ b/src/protocol/parser/pass_typing.rs @@ -26,13 +26,8 @@ /// instead of with HashMaps is there, but it is not really used because of /// time constraints. When time is available, rewrite the system such that /// AST IDs are not needed, and only indices into arrays are used. -/// 2. We're doing a lot of extra work. It seems better to apply the initial -/// type based on expression parents, and immediately apply forced -/// constraints (arg to a fires() call must be port-like). All of the \ -/// progress_xxx calls should then only be concerned with "transmitting" -/// type inference across their parent/child expressions. -/// 3. Remove the `msg` type? -/// 4. Disallow certain types in certain operations (e.g. `Void`). +/// 2. Remove the `msg` type? +/// 3. Disallow certain types in certain operations (e.g. `Void`). macro_rules! debug_log_enabled { () => { false }; @@ -47,7 +42,7 @@ macro_rules! 
debug_log { }; } -use std::collections::{HashMap, HashSet}; +use std::collections::VecDeque; use crate::collections::{ScopedBuffer, ScopedSection, DequeSet}; use crate::protocol::ast::*; @@ -56,12 +51,15 @@ use crate::protocol::parser::ModuleCompilationPhase; use crate::protocol::parser::type_table::*; use crate::protocol::parser::token_parsing::*; use super::visitor::{ - BUFFER_INIT_CAPACITY, + BUFFER_INIT_CAP_LARGE, + BUFFER_INIT_CAP_SMALL, Ctx, - Visitor, - VisitorResult }; +// ----------------------------------------------------------------------------- +// Inference type +// ----------------------------------------------------------------------------- + const VOID_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::Void ]; const MESSAGE_TEMPLATE: [InferenceTypePart; 2] = [ InferenceTypePart::Message, InferenceTypePart::UInt8 ]; const BOOL_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::Bool ]; @@ -223,13 +221,13 @@ impl InferenceType { /// Generates a new InferenceType. The two boolean flags will be checked in /// debug mode. fn new(has_marker: bool, is_done: bool, parts: Vec) -> Self { - if cfg!(debug_assertions) { + dbg_code!({ debug_assert!(!parts.is_empty()); let parts_body_marker = parts.iter().any(|v| v.is_marker()); debug_assert_eq!(has_marker, parts_body_marker); let parts_done = parts.iter().all(|v| v.is_concrete()); debug_assert_eq!(is_done, parts_done, "{:?}", parts); - } + }); Self{ has_marker, is_done, parts } } @@ -806,19 +804,13 @@ enum SingleInferenceResult { Incompatible } -enum DefinitionType{ - Component(ComponentDefinitionId), - Function(FunctionDefinitionId), -} +// ----------------------------------------------------------------------------- +// PassTyping - Public Interface +// ----------------------------------------------------------------------------- -impl DefinitionType { - fn definition_id(&self) -> DefinitionId { - match self { - DefinitionType::Component(v) => v.upcast(), - DefinitionType::Function(v) => v.upcast(), - } - } -} +type InferNodeIndex = usize; +type PolyDataIndex = isize; +type VarDataIndex = usize; pub(crate) struct ResolveQueueElement { // Note that using the `definition_id` and the `monomorph_idx` one may @@ -826,147 +818,345 @@ pub(crate) struct ResolveQueueElement { // the polymorphic arguments to the procedure. pub(crate) root_id: RootId, pub(crate) definition_id: DefinitionId, - pub(crate) reserved_monomorph_idx: i32, + pub(crate) reserved_type_id: TypeId, + pub(crate) reserved_monomorph_index: u32, +} + +pub(crate) type ResolveQueue = VecDeque; + +struct InferenceNode { + // filled in during type inference + expr_type: InferenceType, // result type from expression + expr_id: ExpressionId, // expression that is evaluated + inference_rule: InferenceRule, // rule used to infer node type + parent_index: Option, // parent of inference node + field_index: i32, // index of struct field or tuple member + poly_data_index: PolyDataIndex, // index to inference data for polymorphic types + // filled in once type inference is done + info_type_id: TypeId, + info_variant: ExpressionInfoVariant, +} + +impl InferenceNode { + #[inline] + fn as_expression_info(&self) -> ExpressionInfo { + return ExpressionInfo { + type_id: self.info_type_id, + variant: self.info_variant + } + } } -pub(crate) type ResolveQueue = Vec; +/// Inferencing rule to apply. Some of these are reasonably generic. Other ones +/// require so much custom logic that we'll not try to come up with an +/// abstraction. 
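For orientation, a tiny sketch of how a few expression kinds map onto the generic rules defined below; the operator-to-rule pairing follows the examples given in the rule descriptions, while the template labels such as "integerlike" are illustrative only:

```rust
// Sketch with toy names; the real rules carry template slices and node indices.
#[derive(Debug)]
enum ToyRule {
    // 'self' and the single argument must end up with the same type
    BiEqual { self_template: &'static str },
    // both arguments share a type, 'self' gets a forced result type
    TriEqualArgs { arg_template: &'static str, result: &'static str },
    // 'self' and both arguments all share one type
    TriEqualAll { template: &'static str },
}

fn rule_for(operator: &str) -> ToyRule {
    match operator {
        "~"  => ToyRule::BiEqual { self_template: "integerlike" },
        "==" => ToyRule::TriEqualArgs { arg_template: "any", result: "bool" },
        "+"  => ToyRule::TriEqualAll { template: "numberlike" },
        _    => ToyRule::TriEqualAll { template: "any" },
    }
}

fn main() {
    for op in ["~", "==", "+"] {
        println!("{op:>2} -> {:?}", rule_for(op));
    }
}
```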
+enum InferenceRule { + Noop, + MonoTemplate(InferenceRuleTemplate), + BiEqual(InferenceRuleBiEqual), + TriEqualArgs(InferenceRuleTriEqualArgs), + TriEqualAll(InferenceRuleTriEqualAll), + Concatenate(InferenceRuleTwoArgs), + IndexingExpr(InferenceRuleIndexingExpr), + SlicingExpr(InferenceRuleSlicingExpr), + SelectStructField(InferenceRuleSelectStructField), + SelectTupleMember(InferenceRuleSelectTupleMember), + LiteralStruct(InferenceRuleLiteralStruct), + LiteralEnum, + LiteralUnion(InferenceRuleLiteralUnion), + LiteralArray(InferenceRuleLiteralArray), + LiteralTuple(InferenceRuleLiteralTuple), + CastExpr(InferenceRuleCastExpr), + CallExpr(InferenceRuleCallExpr), + VariableExpr(InferenceRuleVariableExpr), +} -#[derive(Clone)] -struct InferenceExpression { - expr_type: InferenceType, // result type from expression - expr_id: ExpressionId, // expression that is evaluated - field_or_monomorph_idx: i32, // index of field, of index of monomorph array in type table - extra_data_idx: i32, // index of extra data needed for inference +impl InferenceRule { + union_cast_to_ref_method_impl!(as_mono_template, InferenceRuleTemplate, InferenceRule::MonoTemplate); + union_cast_to_ref_method_impl!(as_bi_equal, InferenceRuleBiEqual, InferenceRule::BiEqual); + union_cast_to_ref_method_impl!(as_tri_equal_args, InferenceRuleTriEqualArgs, InferenceRule::TriEqualArgs); + union_cast_to_ref_method_impl!(as_tri_equal_all, InferenceRuleTriEqualAll, InferenceRule::TriEqualAll); + union_cast_to_ref_method_impl!(as_concatenate, InferenceRuleTwoArgs, InferenceRule::Concatenate); + union_cast_to_ref_method_impl!(as_indexing_expr, InferenceRuleIndexingExpr, InferenceRule::IndexingExpr); + union_cast_to_ref_method_impl!(as_slicing_expr, InferenceRuleSlicingExpr, InferenceRule::SlicingExpr); + union_cast_to_ref_method_impl!(as_select_struct_field, InferenceRuleSelectStructField, InferenceRule::SelectStructField); + union_cast_to_ref_method_impl!(as_select_tuple_member, InferenceRuleSelectTupleMember, InferenceRule::SelectTupleMember); + union_cast_to_ref_method_impl!(as_literal_struct, InferenceRuleLiteralStruct, InferenceRule::LiteralStruct); + union_cast_to_ref_method_impl!(as_literal_union, InferenceRuleLiteralUnion, InferenceRule::LiteralUnion); + union_cast_to_ref_method_impl!(as_literal_array, InferenceRuleLiteralArray, InferenceRule::LiteralArray); + union_cast_to_ref_method_impl!(as_literal_tuple, InferenceRuleLiteralTuple, InferenceRule::LiteralTuple); + union_cast_to_ref_method_impl!(as_cast_expr, InferenceRuleCastExpr, InferenceRule::CastExpr); + union_cast_to_ref_method_impl!(as_call_expr, InferenceRuleCallExpr, InferenceRule::CallExpr); + union_cast_to_ref_method_impl!(as_variable_expr, InferenceRuleVariableExpr, InferenceRule::VariableExpr); } -impl Default for InferenceExpression { - fn default() -> Self { - Self{ - expr_type: InferenceType::default(), - expr_id: ExpressionId::new_invalid(), - field_or_monomorph_idx: -1, - extra_data_idx: -1, +// Note: InferenceRuleTemplate is `Copy`, so don't add dynamically allocated +// members in the future (or review places where this struct is copied) +#[derive(Clone, Copy)] +struct InferenceRuleTemplate { + template: &'static [InferenceTypePart], + application: InferenceRuleTemplateApplication, +} + +impl InferenceRuleTemplate { + fn new_none() -> Self { + return Self{ + template: &[], + application: InferenceRuleTemplateApplication::None, + }; + } + + fn new_forced(template: &'static [InferenceTypePart]) -> Self { + return Self{ + template, + application: 
InferenceRuleTemplateApplication::Forced, + }; + } + + fn new_template(template: &'static [InferenceTypePart]) -> Self { + return Self{ + template, + application: InferenceRuleTemplateApplication::Template, } } } +#[derive(Clone, Copy)] +enum InferenceRuleTemplateApplication { + None, // do not apply template, silly, but saves some bytes + Forced, + Template, +} + +/// Type equality applied to 'self' and the argument. An optional template will +/// be applied to 'self' first. Example: "bitwise not" +struct InferenceRuleBiEqual { + template: InferenceRuleTemplate, + argument_index: InferNodeIndex, +} + +/// Type equality applied to two arguments. Template can be applied to 'self' +/// (generally forced, since this rule does not apply a type equality constraint +/// to 'self') and the two arguments. Example: "equality operator" +struct InferenceRuleTriEqualArgs { + argument_template: InferenceRuleTemplate, + result_template: InferenceRuleTemplate, + argument1_index: InferNodeIndex, + argument2_index: InferNodeIndex, +} + +/// Type equality applied to 'self' and two arguments. Template may be +/// optionally applied to 'self'. Example: "addition operator" +struct InferenceRuleTriEqualAll { + template: InferenceRuleTemplate, + argument1_index: InferNodeIndex, + argument2_index: InferNodeIndex, +} + +/// Information for an inference rule that is applied to 'self' and two +/// arguments, see `InferenceRule` for its meaning. +struct InferenceRuleTwoArgs { + argument1_index: InferNodeIndex, + argument2_index: InferNodeIndex, +} + +struct InferenceRuleIndexingExpr { + subject_index: InferNodeIndex, + index_index: InferNodeIndex, +} + +struct InferenceRuleSlicingExpr { + subject_index: InferNodeIndex, + from_index: InferNodeIndex, + to_index: InferNodeIndex, +} + +struct InferenceRuleSelectStructField { + subject_index: InferNodeIndex, + selected_field: Identifier, +} + +struct InferenceRuleSelectTupleMember { + subject_index: InferNodeIndex, + selected_index: u64, +} + +struct InferenceRuleLiteralStruct { + element_indices: Vec, +} + +struct InferenceRuleLiteralUnion { + element_indices: Vec +} + +struct InferenceRuleLiteralArray { + element_indices: Vec +} + +struct InferenceRuleLiteralTuple { + element_indices: Vec +} + +struct InferenceRuleCastExpr { + subject_index: InferNodeIndex, +} + +struct InferenceRuleCallExpr { + argument_indices: Vec +} + +/// Data associated with a variable expression: an expression that reads the +/// value from a variable. +struct InferenceRuleVariableExpr { + var_data_index: VarDataIndex, // shared variable information +} + /// This particular visitor will recurse depth-first into the AST and ensures /// that all expressions have the appropriate types. pub(crate) struct PassTyping { // Current definition we're typechecking. - reserved_idx: i32, - definition_type: DefinitionType, + reserved_type_id: TypeId, + reserved_monomorph_index: u32, + procedure_id: ProcedureDefinitionId, + procedure_kind: ProcedureKind, poly_vars: Vec, + // Temporary variables during construction of inference rulesr + parent_index: Option, // Buffers for iteration over various types var_buffer: ScopedBuffer, expr_buffer: ScopedBuffer, stmt_buffer: ScopedBuffer, bool_buffer: ScopedBuffer, + index_buffer: ScopedBuffer, + definition_buffer: ScopedBuffer, + poly_progress_buffer: ScopedBuffer, // Mapping from parser type to inferred type. We attempt to continue to // specify these types until we're stuck or we've fully determined the type. 
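The "keep specifying until stuck or done" behaviour mentioned above boils down to a worklist fixpoint: whenever a node's type is narrowed, everything linked to it is re-queued. A much-simplified, self-contained sketch (toy `ToyType` lattice and hand-rolled `meet`, not the real `InferenceType` machinery):

```rust
use std::collections::VecDeque;

// Sketch of the fixpoint loop only; types here are a three-element toy lattice.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ToyType { Unknown, IntegerLike, UInt32 }

// Combine two partial types into the most specific one both agree on.
fn meet(a: ToyType, b: ToyType) -> ToyType {
    use ToyType::*;
    match (a, b) {
        (Unknown, other) | (other, Unknown) => other,
        (IntegerLike, other) | (other, IntegerLike) => other,
        (UInt32, UInt32) => UInt32,
    }
}

fn main() {
    // Three nodes forced to end up equal, e.g. `a + b` where `b` is a uint32.
    let mut types = [ToyType::IntegerLike, ToyType::UInt32, ToyType::Unknown];
    let links: Vec<Vec<usize>> = vec![vec![1, 2], vec![0, 2], vec![0, 1]];

    let mut queue: VecDeque<usize> = (0..types.len()).collect();
    while let Some(node) = queue.pop_front() {
        for &other in &links[node] {
            let combined = meet(types[node], types[other]);
            if combined != types[node] {
                // progress on this node: re-check everything linked to it
                types[node] = combined;
                queue.extend(links[node].iter().copied());
            }
        }
    }
    println!("{types:?}"); // all three nodes end up as UInt32
}
```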
- var_types: HashMap, // types of variables - expr_types: Vec, // will be transferred to type table at end - extra_data: Vec, // data for polymorph inference + infer_nodes: Vec, // will be transferred to type table at end + poly_data: Vec, // data for polymorph inference + var_data: Vec, // Keeping track of which expressions need to be reinferred because the // expressions they're linked to made progression on an associated type - expr_queued: DequeSet, + node_queued: DequeSet, } -// TODO: @Rename, this is used for a lot of type inferencing. It seems like -// there is a different underlying architecture waiting to surface. -struct ExtraData { - expr_id: ExpressionId, // the expression with which this data is associated +/// Generic struct that is used to store inferred types associated with +/// polymorphic types. +struct PolyData { + first_rule_application: bool, definition_id: DefinitionId, // the definition, only used for user feedback - /// Progression of polymorphic variables (if any) + /// Inferred types of the polymorphic variables as they are written down + /// at the type's definition. poly_vars: Vec, - /// Progression of types of call arguments or struct members - embedded: Vec, + expr_types: PolyDataTypes, +} + +// silly structure, just so we can use `PolyDataTypeIndex` ergonomically while +// making sure we're still capable of borrowing from `poly_vars`. +struct PolyDataTypes { + /// Inferred types of associated types (e.g. struct fields, tuple members, + /// function arguments). These types may depend on the polymorphic variables + /// defined above. + associated: Vec, + /// Inferred "returned" type (e.g. if a struct field is selected, then this + /// contains the type of the selected field, for a function call it contains + /// the return type). May depend on the polymorphic variables defined above. returned: InferenceType, } -impl Default for ExtraData { - fn default() -> Self { - Self{ - expr_id: ExpressionId::new_invalid(), - definition_id: DefinitionId::new_invalid(), - poly_vars: Vec::new(), - embedded: Vec::new(), - returned: InferenceType::default(), +#[derive(Clone, Copy)] +enum PolyDataTypeIndex { + Associated(usize), // indexes into `PolyData.associated` + Returned, +} + +impl PolyDataTypes { + fn get_type(&self, index: PolyDataTypeIndex) -> &InferenceType { + match index { + PolyDataTypeIndex::Associated(index) => return &self.associated[index], + PolyDataTypeIndex::Returned => return &self.returned, + } + } + + fn get_type_mut(&mut self, index: PolyDataTypeIndex) -> &mut InferenceType { + match index { + PolyDataTypeIndex::Associated(index) => return &mut self.associated[index], + PolyDataTypeIndex::Returned => return &mut self.returned, } } } struct VarData { - /// Type of the variable + var_id: VariableId, var_type: InferenceType, - /// VariableExpressions that use the variable - used_at: Vec, - /// For channel statements we link to the other variable such that when one - /// channel's interior type is resolved, we can also resolve the other one. 
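The `linked_var` field preserves the behaviour described in the removed comment: the two variables introduced by a channel statement point at each other, so narrowing the payload type of one endpoint immediately narrows the other. A small sketch with assumed toy types (`ToyVarData`, `Payload` are illustrative, not the compiler's structures):

```rust
// Sketch: two channel endpoints whose payload types are kept in lockstep.
#[derive(Debug, Clone, PartialEq)]
enum Payload { Unknown, UInt32 }

struct ToyVarData {
    name: &'static str,
    payload: Payload,
    linked_var: Option<usize>, // index of the other channel endpoint, if any
}

fn narrow_payload(vars: &mut Vec<ToyVarData>, index: usize, payload: Payload) {
    vars[index].payload = payload.clone();
    if let Some(other) = vars[index].linked_var {
        if vars[other].payload != payload {
            // the linked endpoint follows without being visited separately
            vars[other].payload = payload;
        }
    }
}

fn main() {
    // Two endpoints declared by a single channel statement.
    let mut vars = vec![
        ToyVarData { name: "tx", payload: Payload::Unknown, linked_var: Some(1) },
        ToyVarData { name: "rx", payload: Payload::Unknown, linked_var: Some(0) },
    ];

    // Something forces the payload type of `rx`; `tx` is updated through the link.
    narrow_payload(&mut vars, 1, Payload::UInt32);
    for var in &vars {
        println!("{}: {:?}", var.name, var.payload);
    }
}
```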
- linked_var: Option, -} - -impl VarData { - fn new_channel(var_type: InferenceType, other_port: VariableId) -> Self { - Self{ var_type, used_at: Vec::new(), linked_var: Some(other_port) } - } - fn new_local(var_type: InferenceType) -> Self { - Self{ var_type, used_at: Vec::new(), linked_var: None } - } + used_at: Vec, // of variable expressions + linked_var: Option, } impl PassTyping { pub(crate) fn new() -> Self { PassTyping { - reserved_idx: -1, - definition_type: DefinitionType::Function(FunctionDefinitionId::new_invalid()), + reserved_type_id: TypeId::new_invalid(), + reserved_monomorph_index: u32::MAX, + procedure_id: ProcedureDefinitionId::new_invalid(), + procedure_kind: ProcedureKind::Function, poly_vars: Vec::new(), - var_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAPACITY), - expr_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAPACITY), - stmt_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAPACITY), - bool_buffer: ScopedBuffer::with_capacity(16), - var_types: HashMap::new(), - expr_types: Vec::new(), - extra_data: Vec::new(), - expr_queued: DequeSet::new(), + parent_index: None, + var_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + expr_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + stmt_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + bool_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + index_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + definition_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + poly_progress_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + infer_nodes: Vec::with_capacity(BUFFER_INIT_CAP_LARGE), + poly_data: Vec::with_capacity(BUFFER_INIT_CAP_SMALL), + var_data: Vec::with_capacity(BUFFER_INIT_CAP_SMALL), + node_queued: DequeSet::new(), } } - pub(crate) fn queue_module_definitions(ctx: &mut Ctx, queue: &mut ResolveQueue) { + pub(crate) fn queue_module_definitions(&mut self, ctx: &mut Ctx, queue: &mut ResolveQueue) { debug_assert_eq!(ctx.module().phase, ModuleCompilationPhase::ValidatedAndLinked); let root_id = ctx.module().root_id; let root = &ctx.heap.protocol_descriptions[root_id]; - for definition_id in &root.definitions { - let definition = &ctx.heap[*definition_id]; + let definitions_section = self.definition_buffer.start_section_initialized(&root.definitions); + + for definition_id in definitions_section.iter_copied() { + let definition = &ctx.heap[definition_id]; - let first_concrete_part = match definition { - Definition::Function(definition) => { + let first_concrete_part_and_procedure_id = match definition { + Definition::Procedure(definition) => { if definition.poly_vars.is_empty() { - Some(ConcreteTypePart::Function(*definition_id, 0)) + if definition.kind == ProcedureKind::Function { + Some((ConcreteTypePart::Function(definition.this, 0), definition.this)) + } else { + Some((ConcreteTypePart::Component(definition.this, 0), definition.this)) + } } else { None } } - Definition::Component(definition) => { - if definition.poly_vars.is_empty() { - Some(ConcreteTypePart::Component(*definition_id, 0)) - } else { - None - } - }, Definition::Enum(_) | Definition::Struct(_) | Definition::Union(_) => None, }; - if let Some(first_concrete_part) = first_concrete_part { + if let Some((first_concrete_part, procedure_id)) = first_concrete_part_and_procedure_id { + let procedure = &mut ctx.heap[procedure_id]; + let monomorph_index = procedure.monomorphs.len() as u32; + procedure.monomorphs.push(ProcedureDefinitionMonomorph::new_invalid()); + let concrete_type 
= ConcreteType{ parts: vec![first_concrete_part] }; - let reserved_idx = ctx.types.reserve_procedure_monomorph_index(definition_id, concrete_type); - queue.push(ResolveQueueElement{ + let type_id = ctx.types.reserve_procedure_monomorph_type_id(&definition_id, concrete_type, monomorph_index); + queue.push_back(ResolveQueueElement{ root_id, - definition_id: *definition_id, - reserved_monomorph_idx: reserved_idx, + definition_id, + reserved_type_id: type_id, + reserved_monomorph_index: monomorph_index, }) } } + + definitions_section.forget(); } pub(crate) fn handle_module_definition( @@ -977,11 +1167,12 @@ impl PassTyping { debug_assert!(self.poly_vars.is_empty()); // Prepare for visiting the definition - self.reserved_idx = element.reserved_monomorph_idx; + self.reserved_type_id = element.reserved_type_id; + self.reserved_monomorph_index = element.reserved_monomorph_index; let proc_base = ctx.types.get_base_definition(&element.definition_id).unwrap(); if proc_base.is_polymorph { - let monomorph = ctx.types.get_monomorph(element.reserved_monomorph_idx); + let monomorph = ctx.types.get_monomorph(element.reserved_type_id); for poly_arg in monomorph.concrete_type.embedded_iter(0) { self.poly_vars.push(ConcreteType{ parts: Vec::from(poly_arg) }); } @@ -995,90 +1186,74 @@ impl PassTyping { } fn reset(&mut self) { - self.reserved_idx = -1; - self.definition_type = DefinitionType::Function(FunctionDefinitionId::new_invalid()); + self.reserved_type_id = TypeId::new_invalid(); + self.procedure_id = ProcedureDefinitionId::new_invalid(); + self.procedure_kind = ProcedureKind::Function; self.poly_vars.clear(); - self.var_types.clear(); - self.expr_types.clear(); - self.extra_data.clear(); - self.expr_queued.clear(); + self.parent_index = None; + + self.infer_nodes.clear(); + self.poly_data.clear(); + self.var_data.clear(); + self.node_queued.clear(); } } -impl Visitor for PassTyping { - // Definitions - - fn visit_component_definition(&mut self, ctx: &mut Ctx, id: ComponentDefinitionId) -> VisitorResult { - self.definition_type = DefinitionType::Component(id); - - let comp_def = &ctx.heap[id]; - debug_assert_eq!(comp_def.poly_vars.len(), self.poly_vars.len(), "component polyvars do not match imposed polyvars"); +// ----------------------------------------------------------------------------- +// PassTyping - Visitor-like implementation +// ----------------------------------------------------------------------------- - debug_log!("{}", "-".repeat(50)); - debug_log!("Visiting component '{}': {}", comp_def.identifier.value.as_str(), id.0.index); - debug_log!("{}", "-".repeat(50)); +type VisitorResult = Result<(), ParseError>; +type VisitExprResult = Result; - // Reserve data for expression types - debug_assert!(self.expr_types.is_empty()); - self.expr_types.resize(comp_def.num_expressions_in_body as usize, Default::default()); - - // Visit parameters - let section = self.var_buffer.start_section_initialized(comp_def.parameters.as_slice()); - for param_id in section.iter_copied() { - let param = &ctx.heap[param_id]; - let var_type = self.determine_inference_type_from_parser_type_elements(¶m.parser_type.elements, true); - debug_assert!(var_type.is_done, "expected component arguments to be concrete types"); - self.var_types.insert(param_id, VarData::new_local(var_type)); - } - section.forget(); +impl PassTyping { + // Definitions - // Visit the body and all of its expressions - let body_stmt_id = ctx.heap[id].body; - self.visit_block_stmt(ctx, body_stmt_id) + fn visit_definition(&mut self, ctx: &mut Ctx, id: 
DefinitionId) -> VisitorResult { + return visitor_recursive_definition_impl!(self, &ctx.heap[id], ctx); } - fn visit_function_definition(&mut self, ctx: &mut Ctx, id: FunctionDefinitionId) -> VisitorResult { - self.definition_type = DefinitionType::Function(id); + fn visit_enum_definition(&mut self, _: &mut Ctx, _: EnumDefinitionId) -> VisitorResult { return Ok(()) } + fn visit_struct_definition(&mut self, _: &mut Ctx, _: StructDefinitionId) -> VisitorResult { return Ok(()) } + fn visit_union_definition(&mut self, _: &mut Ctx, _: UnionDefinitionId) -> VisitorResult { return Ok(()) } + + fn visit_procedure_definition(&mut self, ctx: &mut Ctx, id: ProcedureDefinitionId) -> VisitorResult { + let procedure_def = &ctx.heap[id]; - let func_def = &ctx.heap[id]; - debug_assert_eq!(func_def.poly_vars.len(), self.poly_vars.len(), "function polyvars do not match imposed polyvars"); + self.procedure_id = id; + self.procedure_kind = procedure_def.kind; + let body_id = procedure_def.body; debug_log!("{}", "-".repeat(50)); - debug_log!("Visiting function '{}': {}", func_def.identifier.value.as_str(), id.0.index); - if debug_log_enabled!() { - debug_log!("Polymorphic variables:"); - for (_idx, poly_var) in self.poly_vars.iter().enumerate() { - let mut infer_type_parts = Vec::new(); - Self::determine_inference_type_from_concrete_type( - &mut infer_type_parts, &poly_var.parts - ); - let _infer_type = InferenceType::new(false, true, infer_type_parts); - debug_log!(" - [{:03}] {:?}", _idx, _infer_type.display_name(&ctx.heap)); - } - } + debug_log!("Visiting procedure: '{}' (id: {}, kind: {:?})", procedure_def.identifier.value.as_str(), id.0.index, procedure_def.kind); debug_log!("{}", "-".repeat(50)); - // Reserve data for expression types - debug_assert!(self.expr_types.is_empty()); - self.expr_types.resize(func_def.num_expressions_in_body as usize, Default::default()); - // Visit parameters - let section = self.var_buffer.start_section_initialized(func_def.parameters.as_slice()); + let section = self.var_buffer.start_section_initialized(procedure_def.parameters.as_slice()); for param_id in section.iter_copied() { let param = &ctx.heap[param_id]; let var_type = self.determine_inference_type_from_parser_type_elements(¶m.parser_type.elements, true); debug_assert!(var_type.is_done, "expected function arguments to be concrete types"); - self.var_types.insert(param_id, VarData::new_local(var_type)); + self.var_data.push(VarData{ + var_id: param_id, + var_type, + used_at: Vec::new(), + linked_var: None + }) } section.forget(); // Visit all of the expressions within the body - let body_stmt_id = ctx.heap[id].body; - self.visit_block_stmt(ctx, body_stmt_id) + self.parent_index = None; + return self.visit_block_stmt(ctx, body_id); } // Statements + fn visit_stmt(&mut self, ctx: &mut Ctx, id: StatementId) -> VisitorResult { + return visitor_recursive_statement_impl!(self, &ctx.heap[id], ctx, Ok(())); + } + fn visit_block_stmt(&mut self, ctx: &mut Ctx, id: BlockStatementId) -> VisitorResult { // Transfer statements for traversal let block = &ctx.heap[id]; @@ -1092,14 +1267,22 @@ impl Visitor for PassTyping { Ok(()) } + fn visit_local_stmt(&mut self, ctx: &mut Ctx, id: LocalStatementId) -> VisitorResult { + return visitor_recursive_local_impl!(self, &ctx.heap[id], ctx); + } + fn visit_local_memory_stmt(&mut self, ctx: &mut Ctx, id: MemoryStatementId) -> VisitorResult { let memory_stmt = &ctx.heap[id]; let initial_expr_id = memory_stmt.initial_expr; - // Setup memory statement inference let local = 
&ctx.heap[memory_stmt.variable]; let var_type = self.determine_inference_type_from_parser_type_elements(&local.parser_type.elements, true); - self.var_types.insert(memory_stmt.variable, VarData::new_local(var_type)); + self.var_data.push(VarData{ + var_id: memory_stmt.variable, + var_type, + used_at: Vec::new(), + linked_var: None, + }); // Process the initial value self.visit_assignment_expr(ctx, initial_expr_id)?; @@ -1110,13 +1293,26 @@ impl Visitor for PassTyping { fn visit_local_channel_stmt(&mut self, ctx: &mut Ctx, id: ChannelStatementId) -> VisitorResult { let channel_stmt = &ctx.heap[id]; + let from_var_index = self.var_data.len() as VarDataIndex; + let to_var_index = from_var_index + 1; + let from_local = &ctx.heap[channel_stmt.from]; let from_var_type = self.determine_inference_type_from_parser_type_elements(&from_local.parser_type.elements, true); - self.var_types.insert(from_local.this, VarData::new_channel(from_var_type, channel_stmt.to)); + self.var_data.push(VarData{ + var_id: channel_stmt.from, + var_type: from_var_type, + used_at: Vec::new(), + linked_var: Some(to_var_index), + }); let to_local = &ctx.heap[channel_stmt.to]; let to_var_type = self.determine_inference_type_from_parser_type_elements(&to_local.parser_type.elements, true); - self.var_types.insert(to_local.this, VarData::new_channel(to_var_type, channel_stmt.from)); + self.var_data.push(VarData{ + var_id: channel_stmt.to, + var_type: to_var_type, + used_at: Vec::new(), + linked_var: Some(from_var_index), + }); Ok(()) } @@ -1130,14 +1326,14 @@ impl Visitor for PassTyping { fn visit_if_stmt(&mut self, ctx: &mut Ctx, id: IfStatementId) -> VisitorResult { let if_stmt = &ctx.heap[id]; - let true_body_id = if_stmt.true_body; - let false_body_id = if_stmt.false_body; + let true_body_case = if_stmt.true_case; + let false_body_case = if_stmt.false_case; let test_expr_id = if_stmt.test; self.visit_expr(ctx, test_expr_id)?; - self.visit_block_stmt(ctx, true_body_id)?; - if let Some(false_body_id) = false_body_id { - self.visit_block_stmt(ctx, false_body_id)?; + self.visit_stmt(ctx, true_body_case.body)?; + if let Some(false_body_case) = false_body_case { + self.visit_stmt(ctx, false_body_case.body)?; } Ok(()) @@ -1150,16 +1346,19 @@ impl Visitor for PassTyping { let test_expr_id = while_stmt.test; self.visit_expr(ctx, test_expr_id)?; - self.visit_block_stmt(ctx, body_id)?; + self.visit_stmt(ctx, body_id)?; Ok(()) } + fn visit_break_stmt(&mut self, _: &mut Ctx, _: BreakStatementId) -> VisitorResult { return Ok(()) } + fn visit_continue_stmt(&mut self, _: &mut Ctx, _: ContinueStatementId) -> VisitorResult { return Ok(()) } + fn visit_synchronous_stmt(&mut self, ctx: &mut Ctx, id: SynchronousStatementId) -> VisitorResult { let sync_stmt = &ctx.heap[id]; let body_id = sync_stmt.body; - self.visit_block_stmt(ctx, body_id) + self.visit_stmt(ctx, body_id) } fn visit_fork_stmt(&mut self, ctx: &mut Ctx, id: ForkStatementId) -> VisitorResult { @@ -1167,9 +1366,9 @@ impl Visitor for PassTyping { let left_body_id = fork_stmt.left_body; let right_body_id = fork_stmt.right_body; - self.visit_block_stmt(ctx, left_body_id)?; + self.visit_stmt(ctx, left_body_id)?; if let Some(right_body_id) = right_body_id { - self.visit_block_stmt(ctx, right_body_id)?; + self.visit_stmt(ctx, right_body_id)?; } Ok(()) @@ -1183,7 +1382,7 @@ impl Visitor for PassTyping { for case in &select_stmt.cases { section.push(case.guard); - section.push(case.block.upcast()); + section.push(case.body); } for case_index in 0..num_cases { @@ -1204,284 +1403,569 @@ impl 
Visitor for PassTyping { debug_assert_eq!(return_stmt.expressions.len(), 1); let expr_id = return_stmt.expressions[0]; - self.visit_expr(ctx, expr_id) + self.visit_expr(ctx, expr_id)?; + return Ok(()); } + fn visit_goto_stmt(&mut self, _: &mut Ctx, _: GotoStatementId) -> VisitorResult { return Ok(()) } + fn visit_new_stmt(&mut self, ctx: &mut Ctx, id: NewStatementId) -> VisitorResult { let new_stmt = &ctx.heap[id]; let call_expr_id = new_stmt.expression; - self.visit_call_expr(ctx, call_expr_id) + self.visit_call_expr(ctx, call_expr_id)?; + return Ok(()); } fn visit_expr_stmt(&mut self, ctx: &mut Ctx, id: ExpressionStatementId) -> VisitorResult { let expr_stmt = &ctx.heap[id]; let subexpr_id = expr_stmt.expression; - self.visit_expr(ctx, subexpr_id) + self.visit_expr(ctx, subexpr_id)?; + return Ok(()); } // Expressions - fn visit_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> VisitorResult { + fn visit_expr(&mut self, ctx: &mut Ctx, id: ExpressionId) -> VisitExprResult { + return visitor_recursive_expression_impl!(self, &ctx.heap[id], ctx); + } + + fn visit_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> VisitExprResult { + use AssignmentOperator as AO; + let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id)?; + let self_index = self.insert_initial_inference_node(ctx, upcast_id)?; let assign_expr = &ctx.heap[id]; + let assign_op = assign_expr.operation; let left_expr_id = assign_expr.left; let right_expr_id = assign_expr.right; - self.visit_expr(ctx, left_expr_id)?; - self.visit_expr(ctx, right_expr_id)?; + let old_parent = self.parent_index.replace(self_index); + let left_index = self.visit_expr(ctx, left_expr_id)?; + let right_index = self.visit_expr(ctx, right_expr_id)?; + + let node = &mut self.infer_nodes[self_index]; + let argument_template = match assign_op { + AO::Set => + InferenceRuleTemplate::new_none(), + AO::Concatenated => + InferenceRuleTemplate::new_template(&ARRAYLIKE_TEMPLATE), + AO::Multiplied | AO::Divided | AO::Added | AO::Subtracted => + InferenceRuleTemplate::new_template(&NUMBERLIKE_TEMPLATE), + AO::Remained | AO::ShiftedLeft | AO::ShiftedRight | + AO::BitwiseAnded | AO::BitwiseXored | AO::BitwiseOred => + InferenceRuleTemplate::new_template(&INTEGERLIKE_TEMPLATE), + }; + + node.inference_rule = InferenceRule::TriEqualArgs(InferenceRuleTriEqualArgs{ + argument_template, + result_template: InferenceRuleTemplate::new_forced(&VOID_TEMPLATE), + argument1_index: left_index, + argument2_index: right_index, + }); - self.progress_assignment_expr(ctx, id) + self.parent_index = old_parent; + self.progress_inference_rule_tri_equal_args(ctx, self_index)?; + return Ok(self_index); } - fn visit_binding_expr(&mut self, ctx: &mut Ctx, id: BindingExpressionId) -> VisitorResult { + fn visit_binding_expr(&mut self, ctx: &mut Ctx, id: BindingExpressionId) -> VisitExprResult { let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id)?; + let self_index = self.insert_initial_inference_node(ctx, upcast_id)?; let binding_expr = &ctx.heap[id]; let bound_to_id = binding_expr.bound_to; let bound_from_id = binding_expr.bound_from; - self.visit_expr(ctx, bound_to_id)?; - self.visit_expr(ctx, bound_from_id)?; - - self.progress_binding_expr(ctx, id) + let old_parent = self.parent_index.replace(self_index); + let arg_to_index = self.visit_expr(ctx, bound_to_id)?; + let arg_from_index = self.visit_expr(ctx, bound_from_id)?; + + let node = &mut self.infer_nodes[self_index]; + 
node.inference_rule = InferenceRule::TriEqualArgs(InferenceRuleTriEqualArgs{ + argument_template: InferenceRuleTemplate::new_none(), + result_template: InferenceRuleTemplate::new_forced(&BOOL_TEMPLATE), + argument1_index: arg_to_index, + argument2_index: arg_from_index, + }); + + self.parent_index = old_parent; + self.progress_inference_rule_tri_equal_args(ctx, self_index)?; + return Ok(self_index); } - fn visit_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> VisitorResult { + fn visit_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> VisitExprResult { let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id)?; + let self_index = self.insert_initial_inference_node(ctx, upcast_id)?; let conditional_expr = &ctx.heap[id]; let test_expr_id = conditional_expr.test; let true_expr_id = conditional_expr.true_expression; let false_expr_id = conditional_expr.false_expression; + let old_parent = self.parent_index.replace(self_index); self.visit_expr(ctx, test_expr_id)?; - self.visit_expr(ctx, true_expr_id)?; - self.visit_expr(ctx, false_expr_id)?; - - self.progress_conditional_expr(ctx, id) + let true_index = self.visit_expr(ctx, true_expr_id)?; + let false_index = self.visit_expr(ctx, false_expr_id)?; + + // Note: the test to the conditional expression has already been forced + // to the boolean type. So the only thing we need to do while progressing + // is to apply an equal3 constraint to the arguments and the result of + // the expression. + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::TriEqualAll(InferenceRuleTriEqualAll{ + template: InferenceRuleTemplate::new_none(), + argument1_index: true_index, + argument2_index: false_index, + }); + + self.parent_index = old_parent; + self.progress_inference_rule_tri_equal_all(ctx, self_index)?; + return Ok(self_index); } - fn visit_binary_expr(&mut self, ctx: &mut Ctx, id: BinaryExpressionId) -> VisitorResult { + fn visit_binary_expr(&mut self, ctx: &mut Ctx, id: BinaryExpressionId) -> VisitExprResult { + use BinaryOperator as BO; + let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id)?; + let self_index = self.insert_initial_inference_node(ctx, upcast_id)?; let binary_expr = &ctx.heap[id]; + let binary_op = binary_expr.operation; let lhs_expr_id = binary_expr.left; let rhs_expr_id = binary_expr.right; - self.visit_expr(ctx, lhs_expr_id)?; - self.visit_expr(ctx, rhs_expr_id)?; + let old_parent = self.parent_index.replace(self_index); + let left_index = self.visit_expr(ctx, lhs_expr_id)?; + let right_index = self.visit_expr(ctx, rhs_expr_id)?; + + let inference_rule = match binary_op { + BO::Concatenate => + InferenceRule::Concatenate(InferenceRuleTwoArgs{ + argument1_index: left_index, + argument2_index: right_index, + }), + BO::LogicalAnd | BO::LogicalOr => + InferenceRule::TriEqualAll(InferenceRuleTriEqualAll{ + template: InferenceRuleTemplate::new_forced(&BOOL_TEMPLATE), + argument1_index: left_index, + argument2_index: right_index, + }), + BO::BitwiseOr | BO::BitwiseXor | BO::BitwiseAnd | BO::Remainder | BO::ShiftLeft | BO::ShiftRight => + InferenceRule::TriEqualAll(InferenceRuleTriEqualAll{ + template: InferenceRuleTemplate::new_template(&INTEGERLIKE_TEMPLATE), + argument1_index: left_index, + argument2_index: right_index, + }), + BO::Equality | BO::Inequality => + InferenceRule::TriEqualArgs(InferenceRuleTriEqualArgs{ + argument_template: InferenceRuleTemplate::new_none(), + result_template: 
InferenceRuleTemplate::new_forced(&BOOL_TEMPLATE), + argument1_index: left_index, + argument2_index: right_index, + }), + BO::LessThan | BO::GreaterThan | BO::LessThanEqual | BO::GreaterThanEqual => + InferenceRule::TriEqualArgs(InferenceRuleTriEqualArgs{ + argument_template: InferenceRuleTemplate::new_template(&NUMBERLIKE_TEMPLATE), + result_template: InferenceRuleTemplate::new_forced(&BOOL_TEMPLATE), + argument1_index: left_index, + argument2_index: right_index, + }), + BO::Add | BO::Subtract | BO::Multiply | BO::Divide => + InferenceRule::TriEqualAll(InferenceRuleTriEqualAll{ + template: InferenceRuleTemplate::new_template(&NUMBERLIKE_TEMPLATE), + argument1_index: left_index, + argument2_index: right_index, + }), + }; + + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = inference_rule; - self.progress_binary_expr(ctx, id) + self.parent_index = old_parent; + self.progress_inference_rule(ctx, self_index)?; + return Ok(self_index); } - fn visit_unary_expr(&mut self, ctx: &mut Ctx, id: UnaryExpressionId) -> VisitorResult { + fn visit_unary_expr(&mut self, ctx: &mut Ctx, id: UnaryExpressionId) -> VisitExprResult { + use UnaryOperator as UO; + let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id)?; + let self_index = self.insert_initial_inference_node(ctx, upcast_id)?; let unary_expr = &ctx.heap[id]; + let operation = unary_expr.operation; let arg_expr_id = unary_expr.expression; - self.visit_expr(ctx, arg_expr_id)?; + let old_parent = self.parent_index.replace(self_index); + let argument_index = self.visit_expr(ctx, arg_expr_id)?; - self.progress_unary_expr(ctx, id) + let template = match operation { + UO::Positive | UO::Negative => + InferenceRuleTemplate::new_template(&NUMBERLIKE_TEMPLATE), + UO::BitwiseNot => + InferenceRuleTemplate::new_template(&INTEGERLIKE_TEMPLATE), + UO::LogicalNot => + InferenceRuleTemplate::new_forced(&BOOL_TEMPLATE), + }; + + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::BiEqual(InferenceRuleBiEqual{ + template, argument_index, + }); + + self.parent_index = old_parent; + self.progress_inference_rule_bi_equal(ctx, self_index)?; + return Ok(self_index); } - fn visit_indexing_expr(&mut self, ctx: &mut Ctx, id: IndexingExpressionId) -> VisitorResult { + fn visit_indexing_expr(&mut self, ctx: &mut Ctx, id: IndexingExpressionId) -> VisitExprResult { let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id)?; + let self_index = self.insert_initial_inference_node(ctx, upcast_id)?; let indexing_expr = &ctx.heap[id]; let subject_expr_id = indexing_expr.subject; let index_expr_id = indexing_expr.index; - self.visit_expr(ctx, subject_expr_id)?; - self.visit_expr(ctx, index_expr_id)?; + let old_parent = self.parent_index.replace(self_index); + let subject_index = self.visit_expr(ctx, subject_expr_id)?; + let index_index = self.visit_expr(ctx, index_expr_id)?; // cool name, bro + + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::IndexingExpr(InferenceRuleIndexingExpr{ + subject_index, index_index, + }); - self.progress_indexing_expr(ctx, id) + self.parent_index = old_parent; + self.progress_inference_rule_indexing_expr(ctx, self_index)?; + return Ok(self_index); } - fn visit_slicing_expr(&mut self, ctx: &mut Ctx, id: SlicingExpressionId) -> VisitorResult { + fn visit_slicing_expr(&mut self, ctx: &mut Ctx, id: SlicingExpressionId) -> VisitExprResult { let upcast_id = id.upcast(); - 
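// The inference rules selected by the visitors above come in three shapes. Reading
// them off the progress_* functions further down (a summary, not new behaviour):
//
//   BiEqual      - the expression's type and its single argument's type are unified,
//                  optionally narrowed by a template (unary ops: number-like,
//                  integer-like, or forced bool).
//   TriEqualAll  - expression, argument 1 and argument 2 all unify to one type,
//                  optionally within a template (e.g. `a + b`: all number-like).
//   TriEqualArgs - the two arguments unify with each other under the argument
//                  template, while the expression's own type is set independently by
//                  the result template (e.g. `a < b`: arguments number-like, result
//                  forced to bool; assignment: result forced to void).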
self.insert_initial_expr_inference_type(ctx, upcast_id)?; + let self_index = self.insert_initial_inference_node(ctx, upcast_id)?; let slicing_expr = &ctx.heap[id]; let subject_expr_id = slicing_expr.subject; let from_expr_id = slicing_expr.from_index; let to_expr_id = slicing_expr.to_index; - self.visit_expr(ctx, subject_expr_id)?; - self.visit_expr(ctx, from_expr_id)?; - self.visit_expr(ctx, to_expr_id)?; + let old_parent = self.parent_index.replace(self_index); + let subject_index = self.visit_expr(ctx, subject_expr_id)?; + let from_index = self.visit_expr(ctx, from_expr_id)?; + let to_index = self.visit_expr(ctx, to_expr_id)?; - self.progress_slicing_expr(ctx, id) + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::SlicingExpr(InferenceRuleSlicingExpr{ + subject_index, from_index, to_index, + }); + + self.parent_index = old_parent; + self.progress_inference_rule_slicing_expr(ctx, self_index)?; + return Ok(self_index); } - fn visit_select_expr(&mut self, ctx: &mut Ctx, id: SelectExpressionId) -> VisitorResult { + fn visit_select_expr(&mut self, ctx: &mut Ctx, id: SelectExpressionId) -> VisitExprResult { let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id)?; + let self_index = self.insert_initial_inference_node(ctx, upcast_id)?; let select_expr = &ctx.heap[id]; let subject_expr_id = select_expr.subject; - self.visit_expr(ctx, subject_expr_id)?; + let old_parent = self.parent_index.replace(self_index); + let subject_index = self.visit_expr(ctx, subject_expr_id)?; + + let node = &mut self.infer_nodes[self_index]; + let inference_rule = match &ctx.heap[id].kind { + SelectKind::StructField(field_identifier) => + InferenceRule::SelectStructField(InferenceRuleSelectStructField{ + subject_index, + selected_field: field_identifier.clone(), + }), + SelectKind::TupleMember(member_index) => + InferenceRule::SelectTupleMember(InferenceRuleSelectTupleMember{ + subject_index, + selected_index: *member_index, + }), + }; + node.inference_rule = inference_rule; - self.progress_select_expr(ctx, id) + self.parent_index = old_parent; + self.progress_inference_rule(ctx, self_index)?; + return Ok(self_index); } - fn visit_literal_expr(&mut self, ctx: &mut Ctx, id: LiteralExpressionId) -> VisitorResult { + fn visit_literal_expr(&mut self, ctx: &mut Ctx, id: LiteralExpressionId) -> VisitExprResult { let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id)?; + let self_index = self.insert_initial_inference_node(ctx, upcast_id)?; + + let old_parent = self.parent_index.replace(self_index); let literal_expr = &ctx.heap[id]; match &literal_expr.value { - Literal::Null | Literal::False | Literal::True | - Literal::Integer(_) | Literal::Character(_) | Literal::String(_) => { - // No subexpressions + Literal::Null => { + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::MonoTemplate(InferenceRuleTemplate::new_template(&MESSAGE_TEMPLATE)); + }, + Literal::Integer(_) => { + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::MonoTemplate(InferenceRuleTemplate::new_template(&INTEGERLIKE_TEMPLATE)); + }, + Literal::True | Literal::False => { + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::MonoTemplate(InferenceRuleTemplate::new_forced(&BOOL_TEMPLATE)); + }, + Literal::Character(_) => { + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = 
InferenceRule::MonoTemplate(InferenceRuleTemplate::new_forced(&CHARACTER_TEMPLATE)); + }, + Literal::String(_) => { + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::MonoTemplate(InferenceRuleTemplate::new_forced(&STRING_TEMPLATE)); }, Literal::Struct(literal) => { + // Visit field expressions let mut expr_ids = self.expr_buffer.start_section(); for field in &literal.fields { expr_ids.push(field.value); } - self.insert_initial_struct_polymorph_data(ctx, id); + let mut expr_indices = self.index_buffer.start_section(); for expr_id in expr_ids.iter_copied() { - self.visit_expr(ctx, expr_id)?; + let expr_index = self.visit_expr(ctx, expr_id)?; + expr_indices.push(expr_index); } expr_ids.forget(); + let element_indices = expr_indices.into_vec(); + + // Assign rule and extra data index to inference node + let poly_data_index = self.insert_initial_struct_polymorph_data(ctx, id); + let node = &mut self.infer_nodes[self_index]; + node.poly_data_index = poly_data_index; + node.inference_rule = InferenceRule::LiteralStruct(InferenceRuleLiteralStruct{ + element_indices, + }); }, Literal::Enum(_) => { // Enumerations do not carry any subexpressions, but may still // have a user-defined polymorphic marker variable. For this // reason we may still have to apply inference to this // polymorphic variable - self.insert_initial_enum_polymorph_data(ctx, id); + let poly_data_index = self.insert_initial_enum_polymorph_data(ctx, id); + let node = &mut self.infer_nodes[self_index]; + node.poly_data_index = poly_data_index; + node.inference_rule = InferenceRule::LiteralEnum; }, Literal::Union(literal) => { // May carry subexpressions and polymorphic arguments let expr_ids = self.expr_buffer.start_section_initialized(literal.values.as_slice()); - self.insert_initial_union_polymorph_data(ctx, id); + let poly_data_index = self.insert_initial_union_polymorph_data(ctx, id); + + let mut expr_indices = self.index_buffer.start_section(); + for expr_id in expr_ids.iter_copied() { + let expr_index = self.visit_expr(ctx, expr_id)?; + expr_indices.push(expr_index); + } + expr_ids.forget(); + let element_indices = expr_indices.into_vec(); + let node = &mut self.infer_nodes[self_index]; + node.poly_data_index = poly_data_index; + node.inference_rule = InferenceRule::LiteralUnion(InferenceRuleLiteralUnion{ + element_indices, + }); + }, + Literal::Array(expressions) => { + let expr_ids = self.expr_buffer.start_section_initialized(expressions.as_slice()); + + let mut expr_indices = self.index_buffer.start_section(); for expr_id in expr_ids.iter_copied() { - self.visit_expr(ctx, expr_id)?; + let expr_index = self.visit_expr(ctx, expr_id)?; + expr_indices.push(expr_index); } expr_ids.forget(); + let element_indices = expr_indices.into_vec(); + + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::LiteralArray(InferenceRuleLiteralArray{ + element_indices, + }); }, - Literal::Array(expressions) | Literal::Tuple(expressions) => { + Literal::Tuple(expressions) => { let expr_ids = self.expr_buffer.start_section_initialized(expressions.as_slice()); + + let mut expr_indices = self.index_buffer.start_section(); for expr_id in expr_ids.iter_copied() { - self.visit_expr(ctx, expr_id)?; + let expr_index = self.visit_expr(ctx, expr_id)?; + expr_indices.push(expr_index); } expr_ids.forget(); + let element_indices = expr_indices.into_vec(); + + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::LiteralTuple(InferenceRuleLiteralTuple{ + 
element_indices, + }) } } - self.progress_literal_expr(ctx, id) + self.parent_index = old_parent; + self.progress_inference_rule(ctx, self_index)?; + return Ok(self_index); } - fn visit_cast_expr(&mut self, ctx: &mut Ctx, id: CastExpressionId) -> VisitorResult { + fn visit_cast_expr(&mut self, ctx: &mut Ctx, id: CastExpressionId) -> VisitExprResult { let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id)?; + let self_index = self.insert_initial_inference_node(ctx, upcast_id)?; let cast_expr = &ctx.heap[id]; let subject_expr_id = cast_expr.subject; - self.visit_expr(ctx, subject_expr_id)?; + let old_parent = self.parent_index.replace(self_index); + let subject_index = self.visit_expr(ctx, subject_expr_id)?; + + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::CastExpr(InferenceRuleCastExpr{ + subject_index, + }); - self.progress_cast_expr(ctx, id) + self.parent_index = old_parent; + + // The cast expression is a bit special at this point: the progression + // function simply makes sure input/output types are compatible. But if + // the programmer explicitly specified the output type, then we can + // already perform that inference rule here. + { + let cast_expr = &ctx.heap[id]; + let specified_type = self.determine_inference_type_from_parser_type_elements(&cast_expr.to_type.elements, true); + let _progress = self.apply_template_constraint(ctx, self_index, &specified_type.parts)?; + } + + self.progress_inference_rule_cast_expr(ctx, self_index)?; + return Ok(self_index); } - fn visit_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> VisitorResult { + fn visit_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> VisitExprResult { let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id)?; - self.insert_initial_call_polymorph_data(ctx, id); + let self_index = self.insert_initial_inference_node(ctx, upcast_id)?; + let extra_index = self.insert_initial_call_polymorph_data(ctx, id); - // By default we set the polymorph idx for calls to 0. If the call ends - // up not being a polymorphic one, then we will select the default - // expression types in the type table - let call_expr = &ctx.heap[id]; - self.expr_types[call_expr.unique_id_in_definition as usize].field_or_monomorph_idx = 0; + // By default we set the polymorph idx for calls to 0. If the call + // refers to a non-polymorphic function, then it will be "monomorphed" + // once, hence we end up pointing to the correct instance. 
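// Note on the cast expression above: the annotated target type is turned into an
// inference type and applied to the cast's node right away via
// apply_template_constraint, so by the time the progression rule runs it only has to
// check that the subject's (possibly still partial) type can be cast into that
// target. For a cast whose written target type is, say, u32, the node is pinned to
// u32 on the first visit even if nothing else constrains it yet.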
+        self.infer_nodes[self_index].field_index = 0;
 
         // Visit all arguments
+        let old_parent = self.parent_index.replace(self_index);
+
+        let call_expr = &ctx.heap[id];
         let expr_ids = self.expr_buffer.start_section_initialized(call_expr.arguments.as_slice());
+        let mut expr_indices = self.index_buffer.start_section();
+
         for arg_expr_id in expr_ids.iter_copied() {
-            self.visit_expr(ctx, arg_expr_id)?;
+            let expr_index = self.visit_expr(ctx, arg_expr_id)?;
+            expr_indices.push(expr_index);
         }
         expr_ids.forget();
+        let argument_indices = expr_indices.into_vec();
 
-        self.progress_call_expr(ctx, id)
+        let node = &mut self.infer_nodes[self_index];
+        node.poly_data_index = extra_index;
+        node.inference_rule = InferenceRule::CallExpr(InferenceRuleCallExpr{
+            argument_indices,
+        });
+
+        self.parent_index = old_parent;
+        self.progress_inference_rule_call_expr(ctx, self_index)?;
+        return Ok(self_index);
     }
 
-    fn visit_variable_expr(&mut self, ctx: &mut Ctx, id: VariableExpressionId) -> VisitorResult {
+    fn visit_variable_expr(&mut self, ctx: &mut Ctx, id: VariableExpressionId) -> VisitExprResult {
         let upcast_id = id.upcast();
-        self.insert_initial_expr_inference_type(ctx, upcast_id)?;
+        let self_index = self.insert_initial_inference_node(ctx, upcast_id)?;
 
         let var_expr = &ctx.heap[id];
         debug_assert!(var_expr.declaration.is_some());
+        let old_parent = self.parent_index.replace(self_index);
 
-        // Not pretty: if a binding expression, then this is the first time we
-        // encounter the variable, so we still need to insert the variable data.
         let declaration = &ctx.heap[var_expr.declaration.unwrap()];
-        if !self.var_types.contains_key(&declaration.this) {
-            debug_assert!(declaration.kind == VariableKind::Binding);
+        let mut var_data_index = None;
+        for (index, var_data) in self.var_data.iter().enumerate() {
+            if var_data.var_id == declaration.this {
+                var_data_index = Some(index);
+                break;
+            }
+        }
+
+        let var_data_index = if let Some(var_data_index) = var_data_index {
+            let var_data = &mut self.var_data[var_data_index];
+            var_data.used_at.push(self_index);
+
+            var_data_index
+        } else {
+            // If we're in a binding expression then it might be the first time we
+            // encounter the variable, so add a `VarData` entry.
+            debug_assert_eq!(declaration.kind, VariableKind::Binding);
             let var_type = self.determine_inference_type_from_parser_type_elements(
                 &declaration.parser_type.elements, true
             );
-            self.var_types.insert(declaration.this, VarData{
+            let var_data_index = self.var_data.len();
+            self.var_data.push(VarData{
+                var_id: declaration.this,
                 var_type,
-                used_at: vec![upcast_id],
-                linked_var: None
+                used_at: vec![self_index],
+                linked_var: None,
             });
-        } else {
-            let var_data = self.var_types.get_mut(&declaration.this).unwrap();
-            var_data.used_at.push(upcast_id);
-        }
-        self.progress_variable_expr(ctx, id)
+
+            var_data_index
+        };
+
+        let node = &mut self.infer_nodes[self_index];
+        node.inference_rule = InferenceRule::VariableExpr(InferenceRuleVariableExpr{
+            var_data_index,
+        });
+
+        self.parent_index = old_parent;
+        self.progress_inference_rule_variable_expr(ctx, self_index)?;
+        return Ok(self_index);
     }
 }
 
+// -----------------------------------------------------------------------------
+// PassTyping - Type-inference progression
+// -----------------------------------------------------------------------------
+
 impl PassTyping {
     #[allow(dead_code)] // used when debug flag at the top of this file is true.
- fn debug_get_display_name(&self, ctx: &Ctx, expr_id: ExpressionId) -> String { - let expr_idx = ctx.heap[expr_id].get_unique_id_in_definition(); - let expr_type = &self.expr_types[expr_idx as usize].expr_type; + fn debug_get_display_name(&self, ctx: &Ctx, node_index: InferNodeIndex) -> String { + let expr_type = &self.infer_nodes[node_index].expr_type; expr_type.display_name(&ctx.heap) } fn resolve_types(&mut self, ctx: &mut Ctx, queue: &mut ResolveQueue) -> Result<(), ParseError> { // Keep inferring until we can no longer make any progress - while !self.expr_queued.is_empty() { - // Make as much progress as possible without forced integer - // inference. - while !self.expr_queued.is_empty() { - let next_expr_idx = self.expr_queued.pop_front().unwrap(); - self.progress_expr(ctx, next_expr_idx)?; + while !self.node_queued.is_empty() { + while !self.node_queued.is_empty() { + let node_index = self.node_queued.pop_front().unwrap(); + self.progress_inference_rule(ctx, node_index)?; } // Nothing is queued anymore. However we might have integer literals // whose type cannot be inferred. For convenience's sake we'll // infer these to be s32. - for (infer_expr_idx, infer_expr) in self.expr_types.iter_mut().enumerate() { - let expr_type = &mut infer_expr.expr_type; + for (infer_node_index, infer_node) in self.infer_nodes.iter_mut().enumerate() { + let expr_type = &mut infer_node.expr_type; if !expr_type.is_done && expr_type.parts.len() == 1 && expr_type.parts[0] == InferenceTypePart::IntegerLike { // Force integer type to s32 expr_type.parts[0] = InferenceTypePart::SInt32; expr_type.is_done = true; // Requeue expression (and its parent, if it exists) - self.expr_queued.push_back(infer_expr_idx as i32); - - if let Some(parent_expr) = ctx.heap[infer_expr.expr_id].parent_expr_id() { - let parent_idx = ctx.heap[parent_expr].get_unique_id_in_definition(); - self.expr_queued.push_back(parent_idx); + self.node_queued.push_back(infer_node_index); + if let Some(node_parent_index) = infer_node.parent_index { + self.node_queued.push_back(node_parent_index); } } } @@ -1489,13 +1973,13 @@ impl PassTyping { // Helper for transferring polymorphic variables to concrete types and // checking if they're completely specified - fn inference_type_to_concrete_type( - ctx: &Ctx, expr_id: ExpressionId, inference: &Vec, + fn poly_data_type_to_concrete_type( + ctx: &Ctx, expr_id: ExpressionId, inference_poly_args: &Vec, first_concrete_part: ConcreteTypePart, ) -> Result { // Prepare storage vector let mut num_inference_parts = 0; - for inference_type in inference { + for inference_type in inference_poly_args { num_inference_parts += inference_type.parts.len(); } @@ -1506,11 +1990,11 @@ impl PassTyping { // Go through all polymorphic arguments and add them to the concrete // types. - for (poly_idx, poly_type) in inference.iter().enumerate() { + for (poly_idx, poly_type) in inference_poly_args.iter().enumerate() { if !poly_type.is_done { let expr = &ctx.heap[expr_id]; let definition = match expr { - Expression::Call(expr) => expr.definition, + Expression::Call(expr) => expr.procedure.upcast(), Expression::Literal(expr) => match &expr.value { Literal::Enum(lit) => lit.definition, Literal::Union(lit) => lit.definition, @@ -1534,1430 +2018,1005 @@ impl PassTyping { Ok(concrete_type) } - // Inference is now done. But we may still have uninferred types. So we - // check for these. 
- for infer_expr in self.expr_types.iter_mut() { - if !infer_expr.expr_type.is_done { - let expr = &ctx.heap[infer_expr.expr_id]; - return Err(ParseError::new_error_at_span( - &ctx.module().source, expr.full_span(), format!( - "could not fully infer the type of this expression (got '{}')", - infer_expr.expr_type.display_name(&ctx.heap) - ) - )); - } - - // Expression is fine, check if any extra data is attached - if infer_expr.extra_data_idx < 0 { continue; } + // Every expression checked, and new monomorphs are queued. Transfer the + // expression information to the AST. If this is the first time we're + // visiting this procedure then we assign expression indices as well. + let procedure = &ctx.heap[self.procedure_id]; + let num_infer_nodes = self.infer_nodes.len(); + let mut monomorph = ProcedureDefinitionMonomorph{ + argument_types: Vec::with_capacity(procedure.parameters.len()), + expr_info: Vec::with_capacity(num_infer_nodes), + }; - // Extra data is attached, perform typechecking and transfer - // resolved information to the expression - let extra_data = &self.extra_data[infer_expr.extra_data_idx as usize]; + // For all of the expressions look up the TypeId (or create a new one). + // For function calls and component instantiations figure out if they + // need to be typechecked + for infer_node in self.infer_nodes.iter_mut() { + // Determine type ID + let expr = &ctx.heap[infer_node.expr_id]; + + // TODO: Maybe optimize? Split insertion up into lookup, then clone + // if needed? + let mut concrete_type = ConcreteType::default(); + infer_node.expr_type.write_concrete_type(&mut concrete_type); + let info_type_id = ctx.types.add_monomorphed_type(ctx.modules, ctx.heap, ctx.arch, concrete_type)?; + + // Determine procedure type ID, i.e. a called/instantiated + // procedure's signature. + let info_variant = if let Expression::Call(expr) = expr { + // Construct full function type. If not yet typechecked then + // queue it for typechecking. + let poly_data = &self.poly_data[infer_node.poly_data_index as usize]; + debug_assert!(expr.method.is_user_defined() || expr.method.is_public_builtin()); + let procedure_id = expr.procedure; + let num_poly_vars = poly_data.poly_vars.len() as u32; + + let first_part = match expr.method { + Method::UserFunction => ConcreteTypePart::Function(procedure_id, num_poly_vars), + Method::UserComponent => ConcreteTypePart::Component(procedure_id, num_poly_vars), + _ => ConcreteTypePart::Function(procedure_id, num_poly_vars), + }; - // Note that only call and literal expressions need full inference. - // Select expressions also use `extra_data`, but only for temporary - // storage of the struct type whose field it is selecting. - match &ctx.heap[extra_data.expr_id] { - Expression::Call(expr) => { - // Check if it is not a builtin function. If not, then - // construct the first part of the concrete type. 
- let first_concrete_part = if expr.method == Method::UserFunction { - ConcreteTypePart::Function(expr.definition, extra_data.poly_vars.len() as u32) - } else if expr.method == Method::UserComponent { - ConcreteTypePart::Component(expr.definition, extra_data.poly_vars.len() as u32) - } else { - // Builtin function - continue; - }; - let definition_id = expr.definition; - let concrete_type = inference_type_to_concrete_type( - ctx, extra_data.expr_id, &extra_data.poly_vars, first_concrete_part - )?; + let definition_id = procedure_id.upcast(); + let signature_type = poly_data_type_to_concrete_type( + ctx, infer_node.expr_id, &poly_data.poly_vars, first_part + )?; - match ctx.types.get_procedure_monomorph_index(&definition_id, &concrete_type.parts) { - Some(reserved_idx) => { - // Already typechecked, or already put into the resolve queue - infer_expr.field_or_monomorph_idx = reserved_idx; - }, - None => { - // Not typechecked yet, so add an entry in the queue - let reserved_idx = ctx.types.reserve_procedure_monomorph_index(&definition_id, concrete_type); - infer_expr.field_or_monomorph_idx = reserved_idx; - queue.push(ResolveQueueElement { - root_id: ctx.heap[definition_id].defined_in(), - definition_id, - reserved_monomorph_idx: reserved_idx, - }); - } + let (type_id, monomorph_index) = if let Some(type_id) = ctx.types.get_procedure_monomorph_type_id(&definition_id, &signature_type.parts) { + // Procedure is already typechecked + let monomorph_index = ctx.types.get_monomorph(type_id).variant.as_procedure().monomorph_index; + (type_id, monomorph_index) + } else { + // Procedure is not yet typechecked, reserve a TypeID and a monomorph index + let procedure_to_check = &mut ctx.heap[procedure_id]; + let monomorph_index = procedure_to_check.monomorphs.len() as u32; + procedure_to_check.monomorphs.push(ProcedureDefinitionMonomorph::new_invalid()); + let type_id = ctx.types.reserve_procedure_monomorph_type_id(&definition_id, signature_type, monomorph_index); + + if !procedure_to_check.builtin { + // Only perform typechecking on the user-defined + // procedures + queue.push_back(ResolveQueueElement{ + root_id: ctx.heap[definition_id].defined_in(), + definition_id, + reserved_type_id: type_id, + reserved_monomorph_index: monomorph_index, + }); } - }, - Expression::Literal(expr) => { - let definition_id = match &expr.value { - Literal::Enum(lit) => lit.definition, - Literal::Union(lit) => lit.definition, - Literal::Struct(lit) => lit.definition, - _ => unreachable!(), - }; - let first_concrete_part = ConcreteTypePart::Instance(definition_id, extra_data.poly_vars.len() as u32); - let concrete_type = inference_type_to_concrete_type( - ctx, extra_data.expr_id, &extra_data.poly_vars, first_concrete_part - )?; - let mono_index = ctx.types.add_data_monomorph(ctx.modules, ctx.heap, ctx.arch, definition_id, concrete_type)?; - infer_expr.field_or_monomorph_idx = mono_index; - }, - Expression::Select(_) => { - debug_assert!(infer_expr.field_or_monomorph_idx >= 0); - }, - _ => { - unreachable!("handling extra data for expression {:?}", &ctx.heap[extra_data.expr_id]); - } - } - } - // Every expression checked, and new monomorphs are queued. Transfer the - // expression information to the type table. 
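// The call-expression handling above uses a reserve-then-queue pattern: when a call
// targets a procedure signature that has not been typechecked yet, a TypeId and a
// monomorph index are reserved immediately and the procedure is pushed onto the
// resolve queue. A minimal standalone sketch of that idea, with simplified stand-in
// types (the real code keys on the full ConcreteType signature, not a String):

use std::collections::HashMap;

struct MonomorphRegistry {
    by_signature: HashMap<String, u32>, // signature -> reserved monomorph index
    to_check: Vec<String>,              // signatures still awaiting a typechecking pass
}

impl MonomorphRegistry {
    fn monomorph_for(&mut self, signature: &str) -> u32 {
        if let Some(&index) = self.by_signature.get(signature) {
            return index; // already typechecked, or already reserved and queued
        }
        // Reserve the index first so (mutually) recursive calls resolve to it, then
        // queue the signature so the procedure body itself gets typechecked later.
        let index = self.by_signature.len() as u32;
        self.by_signature.insert(signature.to_string(), index);
        self.to_check.push(signature.to_string());
        index
    }
}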
- let procedure_arguments = match &self.definition_type { - DefinitionType::Component(id) => { - let definition = &ctx.heap[*id]; - &definition.parameters - }, - DefinitionType::Function(id) => { - let definition = &ctx.heap[*id]; - &definition.parameters - }, - }; + (type_id, monomorph_index) + }; - let target = ctx.types.get_procedure_monomorph_mut(self.reserved_idx); - debug_assert!(target.arg_types.is_empty()); // makes sure we never queue a procedure's type inferencing twice - debug_assert!(target.expr_data.is_empty()); + ExpressionInfoVariant::Procedure(type_id, monomorph_index) + } else if let Expression::Select(_expr) = expr { + ExpressionInfoVariant::Select(infer_node.field_index) + } else { + ExpressionInfoVariant::Generic + }; - // - Write the arguments to the procedure - target.arg_types.reserve(procedure_arguments.len()); - for argument_id in procedure_arguments { - let mut concrete = ConcreteType::default(); - let argument_type = self.var_types.get(argument_id).unwrap(); - argument_type.var_type.write_concrete_type(&mut concrete); - target.arg_types.push(concrete); + infer_node.info_type_id = info_type_id; + infer_node.info_variant = info_variant; } - // - Write the expression data - target.expr_data.reserve(self.expr_types.len()); - for infer_expr in self.expr_types.iter() { + // Write the types of the arguments + let procedure = &ctx.heap[self.procedure_id]; + for parameter_id in procedure.parameters.iter().copied() { let mut concrete = ConcreteType::default(); - infer_expr.expr_type.write_concrete_type(&mut concrete); - target.expr_data.push(MonomorphExpression{ - expr_type: concrete, - field_or_monomorph_idx: infer_expr.field_or_monomorph_idx - }); + let var_data = self.var_data.iter().find(|v| v.var_id == parameter_id).unwrap(); + var_data.var_type.write_concrete_type(&mut concrete); + let type_id = ctx.types.add_monomorphed_type(ctx.modules, ctx.heap, ctx.arch, concrete)?; + monomorph.argument_types.push(type_id) } - Ok(()) - } - - fn progress_expr(&mut self, ctx: &mut Ctx, idx: i32) -> Result<(), ParseError> { - let id = self.expr_types[idx as usize].expr_id; - match &ctx.heap[id] { - Expression::Assignment(expr) => { - let id = expr.this; - self.progress_assignment_expr(ctx, id) - }, - Expression::Binding(expr) => { - let id = expr.this; - self.progress_binding_expr(ctx, id) - }, - Expression::Conditional(expr) => { - let id = expr.this; - self.progress_conditional_expr(ctx, id) - }, - Expression::Binary(expr) => { - let id = expr.this; - self.progress_binary_expr(ctx, id) - }, - Expression::Unary(expr) => { - let id = expr.this; - self.progress_unary_expr(ctx, id) - }, - Expression::Indexing(expr) => { - let id = expr.this; - self.progress_indexing_expr(ctx, id) - }, - Expression::Slicing(expr) => { - let id = expr.this; - self.progress_slicing_expr(ctx, id) - }, - Expression::Select(expr) => { - let id = expr.this; - self.progress_select_expr(ctx, id) - }, - Expression::Literal(expr) => { - let id = expr.this; - self.progress_literal_expr(ctx, id) - }, - Expression::Cast(expr) => { - let id = expr.this; - self.progress_cast_expr(ctx, id) - }, - Expression::Call(expr) => { - let id = expr.this; - self.progress_call_expr(ctx, id) - }, - Expression::Variable(expr) => { - let id = expr.this; - self.progress_variable_expr(ctx, id) + // Determine if we have already assigned type indices to the expressions + // before (the indices that, for a monomorph, can retrieve the type of + // the expression). 
+ let has_type_indices = self.reserved_monomorph_index > 0; + if has_type_indices { + // already have indices, so resize and then index into it + debug_assert!(monomorph.expr_info.is_empty()); + monomorph.expr_info.resize(num_infer_nodes, ExpressionInfo::new_invalid()); + for infer_node in self.infer_nodes.iter() { + let type_index = ctx.heap[infer_node.expr_id].type_index(); + monomorph.expr_info[type_index as usize] = infer_node.as_expression_info(); + } + } else { + // no indices yet, need to be assigned in AST + for infer_node in self.infer_nodes.iter() { + let type_index = monomorph.expr_info.len(); + monomorph.expr_info.push(infer_node.as_expression_info()); + *ctx.heap[infer_node.expr_id].type_index_mut() = type_index as i32; } } - } - fn progress_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> Result<(), ParseError> { - use AssignmentOperator as AO; + // Push the information into the AST + let procedure = &mut ctx.heap[self.procedure_id]; + procedure.monomorphs[self.reserved_monomorph_index as usize] = monomorph; - let upcast_id = id.upcast(); + Ok(()) + } - let expr = &ctx.heap[id]; - let arg1_expr_id = expr.left; - let arg2_expr_id = expr.right; + fn progress_inference_rule(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + use InferenceRule as IR; - debug_log!("Assignment expr '{:?}': {}", expr.operation, upcast_id.index); - debug_log!(" * Before:"); - debug_log!(" - Arg1 type: {}", self.debug_get_display_name(ctx, arg1_expr_id)); - debug_log!(" - Arg2 type: {}", self.debug_get_display_name(ctx, arg2_expr_id)); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); + let node = &self.infer_nodes[node_index]; + match &node.inference_rule { + IR::Noop => + unreachable!(), + IR::MonoTemplate(_) => + self.progress_inference_rule_mono_template(ctx, node_index), + IR::BiEqual(_) => + self.progress_inference_rule_bi_equal(ctx, node_index), + IR::TriEqualArgs(_) => + self.progress_inference_rule_tri_equal_args(ctx, node_index), + IR::TriEqualAll(_) => + self.progress_inference_rule_tri_equal_all(ctx, node_index), + IR::Concatenate(_) => + self.progress_inference_rule_concatenate(ctx, node_index), + IR::IndexingExpr(_) => + self.progress_inference_rule_indexing_expr(ctx, node_index), + IR::SlicingExpr(_) => + self.progress_inference_rule_slicing_expr(ctx, node_index), + IR::SelectStructField(_) => + self.progress_inference_rule_select_struct_field(ctx, node_index), + IR::SelectTupleMember(_) => + self.progress_inference_rule_select_tuple_member(ctx, node_index), + IR::LiteralStruct(_) => + self.progress_inference_rule_literal_struct(ctx, node_index), + IR::LiteralEnum => + self.progress_inference_rule_literal_enum(ctx, node_index), + IR::LiteralUnion(_) => + self.progress_inference_rule_literal_union(ctx, node_index), + IR::LiteralArray(_) => + self.progress_inference_rule_literal_array(ctx, node_index), + IR::LiteralTuple(_) => + self.progress_inference_rule_literal_tuple(ctx, node_index), + IR::CastExpr(_) => + self.progress_inference_rule_cast_expr(ctx, node_index), + IR::CallExpr(_) => + self.progress_inference_rule_call_expr(ctx, node_index), + IR::VariableExpr(_) => + self.progress_inference_rule_variable_expr(ctx, node_index), + } + } - // Assignment does not return anything (it operates like a statement) - let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &VOID_TEMPLATE)?; + fn progress_inference_rule_mono_template(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), 
ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = *node.inference_rule.as_mono_template(); - // Apply forced constraint to LHS value - let progress_forced = match expr.operation { - AO::Set => - false, - AO::Concatenated => - self.apply_template_constraint(ctx, arg1_expr_id, &ARRAYLIKE_TEMPLATE)?, - AO::Multiplied | AO::Divided | AO::Added | AO::Subtracted => - self.apply_template_constraint(ctx, arg1_expr_id, &NUMBERLIKE_TEMPLATE)?, - AO::Remained | AO::ShiftedLeft | AO::ShiftedRight | - AO::BitwiseAnded | AO::BitwiseXored | AO::BitwiseOred => - self.apply_template_constraint(ctx, arg1_expr_id, &INTEGERLIKE_TEMPLATE)?, - }; + let progress = self.progress_template(ctx, node_index, rule.application, rule.template)?; + if progress { self.queue_node_parent(node_index); } - let (progress_arg1, progress_arg2) = self.apply_equal2_constraint( - ctx, upcast_id, arg1_expr_id, 0, arg2_expr_id, 0 - )?; + return Ok(()); + } - debug_log!(" * After:"); - debug_log!(" - Arg1 type [{}]: {}", progress_forced || progress_arg1, self.debug_get_display_name(ctx, arg1_expr_id)); - debug_log!(" - Arg2 type [{}]: {}", progress_arg2, self.debug_get_display_name(ctx, arg2_expr_id)); - debug_log!(" - Expr type [{}]: {}", progress_expr, self.debug_get_display_name(ctx, upcast_id)); + fn progress_inference_rule_bi_equal(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_bi_equal(); + let template = rule.template; + let arg_index = rule.argument_index; + let base_progress = self.progress_template(ctx, node_index, template.application, template.template)?; + let (node_progress, arg_progress) = self.apply_equal2_constraint(ctx, node_index, node_index, 0, arg_index, 0)?; - if progress_expr { self.queue_expr_parent(ctx, upcast_id); } - if progress_forced || progress_arg1 { self.queue_expr(ctx, arg1_expr_id); } - if progress_arg2 { self.queue_expr(ctx, arg2_expr_id); } + if base_progress || node_progress { self.queue_node_parent(node_index); } + if arg_progress { self.queue_node(arg_index); } - Ok(()) + return Ok(()) } - fn progress_binding_expr(&mut self, ctx: &mut Ctx, id: BindingExpressionId) -> Result<(), ParseError> { - let upcast_id = id.upcast(); - let binding_expr = &ctx.heap[id]; - let bound_from_id = binding_expr.bound_from; - let bound_to_id = binding_expr.bound_to; + fn progress_inference_rule_tri_equal_args(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_tri_equal_args(); - // Output is always a boolean. The two arguments should be of equal - // type. 
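// Queueing discipline used by the progress_* rules above and below: whenever a
// node's own type gains information its parent node is re-queued (the parent may now
// be able to narrow further), and whenever an argument's type gains information that
// argument node itself is re-queued. Nodes whose types did not change are left
// alone, which is what keeps the worklist from looping forever.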
- let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?; - let (progress_from, progress_to) = self.apply_equal2_constraint(ctx, upcast_id, bound_from_id, 0, bound_to_id, 0)?; + let result_template = rule.result_template; + let argument_template = rule.argument_template; + let arg1_index = rule.argument1_index; + let arg2_index = rule.argument2_index; - if progress_expr { self.queue_expr_parent(ctx, upcast_id); } - if progress_from { self.queue_expr(ctx, bound_from_id); } - if progress_to { self.queue_expr(ctx, bound_to_id); } + let self_template_progress = self.progress_template(ctx, node_index, result_template.application, result_template.template)?; + let arg1_template_progress = self.progress_template(ctx, arg1_index, argument_template.application, argument_template.template)?; + let (arg1_progress, arg2_progress) = self.apply_equal2_constraint(ctx, node_index, arg1_index, 0, arg2_index, 0)?; - Ok(()) - } - - fn progress_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> Result<(), ParseError> { - // Note: test expression type is already enforced - let upcast_id = id.upcast(); - let expr = &ctx.heap[id]; - let arg1_expr_id = expr.true_expression; - let arg2_expr_id = expr.false_expression; - - debug_log!("Conditional expr: {}", upcast_id.index); - debug_log!(" * Before:"); - debug_log!(" - Arg1 type: {}", self.debug_get_display_name(ctx, arg1_expr_id)); - debug_log!(" - Arg2 type: {}", self.debug_get_display_name(ctx, arg2_expr_id)); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); - - // I keep confusing myself: this applies equality of types between the - // condition branches' types, and the result from the conditional - // expression, because the result from the conditional is one of the - // branches. - let (progress_expr, progress_arg1, progress_arg2) = self.apply_equal3_constraint( - ctx, upcast_id, arg1_expr_id, arg2_expr_id, 0 - )?; - - debug_log!(" * After:"); - debug_log!(" - Arg1 type [{}]: {}", progress_arg1, self.debug_get_display_name(ctx, arg1_expr_id)); - debug_log!(" - Arg2 type [{}]: {}", progress_arg2, self.debug_get_display_name(ctx, arg2_expr_id)); - debug_log!(" - Expr type [{}]: {}", progress_expr, self.debug_get_display_name(ctx, upcast_id)); - - if progress_expr { self.queue_expr_parent(ctx, upcast_id); } - if progress_arg1 { self.queue_expr(ctx, arg1_expr_id); } - if progress_arg2 { self.queue_expr(ctx, arg2_expr_id); } + if self_template_progress { self.queue_node_parent(node_index); } + if arg1_template_progress || arg1_progress { self.queue_node(arg1_index); } + if arg2_progress { self.queue_node(arg2_index); } - Ok(()) + return Ok(()); } - fn progress_binary_expr(&mut self, ctx: &mut Ctx, id: BinaryExpressionId) -> Result<(), ParseError> { - // Note: our expression type might be fixed by our parent, but we still - // need to make sure it matches the type associated with our operation. 
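// The integer offsets passed to apply_equal2_constraint / apply_equal3_constraint
// (the trailing 0s here, and the 1 in the concatenate rule below) select where in
// the inference type the unification starts: offset 0 unifies the complete types,
// while offset 1 skips the outer constructor so that, for two array-like operands,
// only their element types are forced to match.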
- use BinaryOperator as BO; + fn progress_inference_rule_tri_equal_all(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_tri_equal_all(); - let upcast_id = id.upcast(); - let expr = &ctx.heap[id]; - let arg1_id = expr.left; - let arg2_id = expr.right; - - debug_log!("Binary expr '{:?}': {}", expr.operation, upcast_id.index); - debug_log!(" * Before:"); - debug_log!(" - Arg1 type: {}", self.debug_get_display_name(ctx, arg1_id)); - debug_log!(" - Arg2 type: {}", self.debug_get_display_name(ctx, arg2_id)); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); - - let (progress_expr, progress_arg1, progress_arg2) = match expr.operation { - BO::Concatenate => { - // Two cases: if one of the arguments or the output type is a - // string, then all must be strings. Otherwise the arguments - // must be arraylike and the output will be a array. - let (expr_is_str, expr_is_not_str) = self.type_is_certainly_or_certainly_not_string(ctx, upcast_id); - let (arg1_is_str, arg1_is_not_str) = self.type_is_certainly_or_certainly_not_string(ctx, arg1_id); - let (arg2_is_str, arg2_is_not_str) = self.type_is_certainly_or_certainly_not_string(ctx, arg2_id); - - let someone_is_str = expr_is_str || arg1_is_str || arg2_is_str; - let someone_is_not_str = expr_is_not_str || arg1_is_not_str || arg2_is_not_str; - - // Note: this statement is an expression returning the progression bools - if someone_is_str { - // One of the arguments is a string, then all must be strings - self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id, 0)? - } else { - let progress_expr = if someone_is_not_str { - // Output must be a normal array - self.apply_template_constraint(ctx, upcast_id, &ARRAY_TEMPLATE)? - } else { - // Output may still be anything - self.apply_template_constraint(ctx, upcast_id, &ARRAYLIKE_TEMPLATE)? 
- }; - - let progress_arg1 = self.apply_template_constraint(ctx, arg1_id, &ARRAYLIKE_TEMPLATE)?; - let progress_arg2 = self.apply_template_constraint(ctx, arg2_id, &ARRAYLIKE_TEMPLATE)?; - - // If they're all arraylike, then we want the subtype to match - let (subtype_expr, subtype_arg1, subtype_arg2) = - self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id, 1)?; - - (progress_expr || subtype_expr, progress_arg1 || subtype_arg1, progress_arg2 || subtype_arg2) - } - }, - BO::LogicalAnd => { - // Forced boolean on all - let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?; - let progress_arg1 = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?; - let progress_arg2 = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?; + let template = rule.template; + let arg1_index = rule.argument1_index; + let arg2_index = rule.argument2_index; - (progress_expr, progress_arg1, progress_arg2) - }, - BO::LogicalOr => { - // Forced boolean on all - let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?; - let progress_arg1 = self.apply_forced_constraint(ctx, arg1_id, &BOOL_TEMPLATE)?; - let progress_arg2 = self.apply_forced_constraint(ctx, arg2_id, &BOOL_TEMPLATE)?; - - (progress_expr, progress_arg1, progress_arg2) - }, - BO::BitwiseOr | BO::BitwiseXor | BO::BitwiseAnd | BO::Remainder | BO::ShiftLeft | BO::ShiftRight => { - // All equal of integer type - let progress_base = self.apply_template_constraint(ctx, upcast_id, &INTEGERLIKE_TEMPLATE)?; - let (progress_expr, progress_arg1, progress_arg2) = - self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id, 0)?; + let template_progress = self.progress_template(ctx, node_index, template.application, template.template)?; + let (node_progress, arg1_progress, arg2_progress) = + self.apply_equal3_constraint(ctx, node_index, arg1_index, arg2_index, 0)?; - (progress_base || progress_expr, progress_base || progress_arg1, progress_base || progress_arg2) - }, - BO::Equality | BO::Inequality => { - // Equal2 on args, forced boolean output - let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?; - let (progress_arg1, progress_arg2) = - self.apply_equal2_constraint(ctx, upcast_id, arg1_id, 0, arg2_id, 0)?; + if template_progress || node_progress { self.queue_node_parent(node_index); } + if arg1_progress { self.queue_node(arg1_index); } + if arg2_progress { self.queue_node(arg2_index); } - (progress_expr, progress_arg1, progress_arg2) - }, - BO::LessThan | BO::GreaterThan | BO::LessThanEqual | BO::GreaterThanEqual => { - // Equal2 on args with numberlike type, forced boolean output - let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?; - let progress_arg_base = self.apply_template_constraint(ctx, arg1_id, &NUMBERLIKE_TEMPLATE)?; - let (progress_arg1, progress_arg2) = - self.apply_equal2_constraint(ctx, upcast_id, arg1_id, 0, arg2_id, 0)?; - - (progress_expr, progress_arg_base || progress_arg1, progress_arg_base || progress_arg2) - }, - BO::Add | BO::Subtract | BO::Multiply | BO::Divide => { - // All equal of number type - let progress_base = self.apply_template_constraint(ctx, upcast_id, &NUMBERLIKE_TEMPLATE)?; - let (progress_expr, progress_arg1, progress_arg2) = - self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id, 0)?; - - (progress_base || progress_expr, progress_base || progress_arg1, progress_base || progress_arg2) - }, - }; - - debug_log!(" * After:"); - debug_log!(" - Arg1 type [{}]: {}", progress_arg1, 
self.debug_get_display_name(ctx, arg1_id)); - debug_log!(" - Arg2 type [{}]: {}", progress_arg2, self.debug_get_display_name(ctx, arg2_id)); - debug_log!(" - Expr type [{}]: {}", progress_expr, self.debug_get_display_name(ctx, upcast_id)); - - if progress_expr { self.queue_expr_parent(ctx, upcast_id); } - if progress_arg1 { self.queue_expr(ctx, arg1_id); } - if progress_arg2 { self.queue_expr(ctx, arg2_id); } - - Ok(()) + return Ok(()); } - fn progress_unary_expr(&mut self, ctx: &mut Ctx, id: UnaryExpressionId) -> Result<(), ParseError> { - use UnaryOperator as UO; + fn progress_inference_rule_concatenate(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_concatenate(); + let arg1_index = rule.argument1_index; + let arg2_index = rule.argument2_index; + + // Two cases: one of the arguments is a string (then all must be), or + // one of the arguments is an array (and all must be arrays). + let (expr_is_str, expr_is_not_str) = self.type_is_certainly_or_certainly_not_string(node_index); + let (arg1_is_str, arg1_is_not_str) = self.type_is_certainly_or_certainly_not_string(arg1_index); + let (arg2_is_str, arg2_is_not_str) = self.type_is_certainly_or_certainly_not_string(arg2_index); + + let someone_is_str = expr_is_str || arg1_is_str || arg2_is_str; + let someone_is_not_str = expr_is_not_str || arg1_is_not_str || arg2_is_not_str; + // Note: this statement is an expression returning the progression bools + let (node_progress, arg1_progress, arg2_progress) = if someone_is_str { + // One of the arguments is a string, then all must be strings + self.apply_equal3_constraint(ctx, node_index, arg1_index, arg2_index, 0)? + } else { + let progress_expr = if someone_is_not_str { + // Output must be a normal array + self.apply_template_constraint(ctx, node_index, &ARRAY_TEMPLATE)? + } else { + // Output may still be anything + self.apply_template_constraint(ctx, node_index, &ARRAYLIKE_TEMPLATE)? 
+ }; - let upcast_id = id.upcast(); - let expr = &ctx.heap[id]; - let arg_id = expr.expression; - - debug_log!("Unary expr '{:?}': {}", expr.operation, upcast_id.index); - debug_log!(" * Before:"); - debug_log!(" - Arg type: {}", self.debug_get_display_name(ctx, arg_id)); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); - - let (progress_expr, progress_arg) = match expr.operation { - UO::Positive | UO::Negative => { - // Equal types of numeric class - let progress_base = self.apply_template_constraint(ctx, upcast_id, &NUMBERLIKE_TEMPLATE)?; - let (progress_expr, progress_arg) = - self.apply_equal2_constraint(ctx, upcast_id, upcast_id, 0, arg_id, 0)?; - - (progress_base || progress_expr, progress_base || progress_arg) - }, - UO::BitwiseNot => { - // Equal types of integer class - let progress_base = self.apply_template_constraint(ctx, upcast_id, &INTEGERLIKE_TEMPLATE)?; - let (progress_expr, progress_arg) = - self.apply_equal2_constraint(ctx, upcast_id, upcast_id, 0, arg_id, 0)?; + let progress_arg1 = self.apply_template_constraint(ctx, arg1_index, &ARRAYLIKE_TEMPLATE)?; + let progress_arg2 = self.apply_template_constraint(ctx, arg2_index, &ARRAYLIKE_TEMPLATE)?; - (progress_base || progress_expr, progress_base || progress_arg) - }, - UO::LogicalNot => { - // Both bools - let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?; - let progress_arg = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?; - (progress_expr, progress_arg) - } - }; + // If they're all arraylike, then we want the subtype to match + let (subtype_expr, subtype_arg1, subtype_arg2) = + self.apply_equal3_constraint(ctx, node_index, arg1_index, arg2_index, 1)?; - debug_log!(" * After:"); - debug_log!(" - Arg type [{}]: {}", progress_arg, self.debug_get_display_name(ctx, arg_id)); - debug_log!(" - Expr type [{}]: {}", progress_expr, self.debug_get_display_name(ctx, upcast_id)); + (progress_expr || subtype_expr, progress_arg1 || subtype_arg1, progress_arg2 || subtype_arg2) + }; - if progress_expr { self.queue_expr_parent(ctx, upcast_id); } - if progress_arg { self.queue_expr(ctx, arg_id); } + if node_progress { self.queue_node_parent(node_index); } + if arg1_progress { self.queue_node(arg1_index); } + if arg2_progress { self.queue_node(arg2_index); } - Ok(()) + return Ok(()) } - fn progress_indexing_expr(&mut self, ctx: &mut Ctx, id: IndexingExpressionId) -> Result<(), ParseError> { - let upcast_id = id.upcast(); - let expr = &ctx.heap[id]; - let subject_id = expr.subject; - let index_id = expr.index; - - debug_log!("Indexing expr: {}", upcast_id.index); - debug_log!(" * Before:"); - debug_log!(" - Subject type: {}", self.debug_get_display_name(ctx, subject_id)); - debug_log!(" - Index type: {}", self.debug_get_display_name(ctx, index_id)); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); + fn progress_inference_rule_indexing_expr(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_indexing_expr(); + let subject_index = rule.subject_index; + let index_index = rule.index_index; // which one? 
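+ // (this is the inference node of the index operand, as opposed to the
+ // subject being indexed; both are indices into `self.infer_nodes`)
+ //
+ // A note on the part offsets used by `apply_equal2_constraint` below (a
+ // sketch; the exact part layout lives in `InferenceType`): inference types
+ // are stored as a flattened preorder list of parts, so part 0 is the root of
+ // the type and part 1 is its first child subtree. Equating part 0 of this
+ // node with part 1 of the subject therefore expresses "if the subject is
+ // `Array<T>`, then this expression has type `T`" without knowing `T` up front.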
- // Make sure subject is arraylike and index is integerlike
- let progress_subject_base = self.apply_template_constraint(ctx, subject_id, &ARRAYLIKE_TEMPLATE)?;
- let progress_index = self.apply_template_constraint(ctx, index_id, &INTEGERLIKE_TEMPLATE)?;
+ // Subject is arraylike, index is integerlike
+ let subject_template_progress = self.apply_template_constraint(ctx, subject_index, &ARRAYLIKE_TEMPLATE)?;
+ let index_template_progress = self.apply_template_constraint(ctx, index_index, &INTEGERLIKE_TEMPLATE)?;
- // Make sure if output is of T then subject is Array
- let (progress_expr, progress_subject) =
- self.apply_equal2_constraint(ctx, upcast_id, upcast_id, 0, subject_id, 1)?;
+ // If subject is type `Array`, then expr type is `T`
+ let (node_progress, subject_progress) =
+ self.apply_equal2_constraint(ctx, node_index, node_index, 0, subject_index, 1)?;
- debug_log!(" * After:");
- debug_log!(" - Subject type [{}]: {}", progress_subject_base || progress_subject, self.debug_get_display_name(ctx, subject_id));
- debug_log!(" - Index type [{}]: {}", progress_index, self.debug_get_display_name(ctx, index_id));
- debug_log!(" - Expr type [{}]: {}", progress_expr, self.debug_get_display_name(ctx, upcast_id));
+ if node_progress { self.queue_node_parent(node_index); }
+ if subject_template_progress || subject_progress { self.queue_node(subject_index); }
+ if index_template_progress { self.queue_node(index_index); }
- if progress_expr { self.queue_expr_parent(ctx, upcast_id); }
- if progress_subject_base || progress_subject { self.queue_expr(ctx, subject_id); }
- if progress_index { self.queue_expr(ctx, index_id); }
-
- Ok(())
+ return Ok(());
 }
- fn progress_slicing_expr(&mut self, ctx: &mut Ctx, id: SlicingExpressionId) -> Result<(), ParseError> {
- let upcast_id = id.upcast();
- let expr = &ctx.heap[id];
- let subject_id = expr.subject;
- let from_id = expr.from_index;
- let to_id = expr.to_index;
-
- debug_log!("Slicing expr: {}", upcast_id.index);
- debug_log!(" * Before:");
- debug_log!(" - Subject type: {}", self.debug_get_display_name(ctx, subject_id));
- debug_log!(" - FromIdx type: {}", self.debug_get_display_name(ctx, from_id));
- debug_log!(" - ToIdx type: {}", self.debug_get_display_name(ctx, to_id));
- debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id));
-
- // Make sure subject is arraylike and indices are of equal integerlike
- let progress_subject_base = self.apply_template_constraint(ctx, subject_id, &ARRAYLIKE_TEMPLATE)?;
- let progress_idx_base = self.apply_template_constraint(ctx, from_id, &INTEGERLIKE_TEMPLATE)?;
- let (progress_from, progress_to) = self.apply_equal2_constraint(ctx, upcast_id, from_id, 0, to_id, 0)?;
-
- let (progress_expr, progress_subject) = match self.type_is_certainly_or_certainly_not_string(ctx, subject_id) {
- (true, _) => {
- // Certainly a string
- (self.apply_forced_constraint(ctx, upcast_id, &STRING_TEMPLATE)?, false)
- },
- (_, true) => {
- // Certainly not a string
- let progress_expr_base = self.apply_template_constraint(ctx, upcast_id, &SLICE_TEMPLATE)?;
- let (progress_expr, progress_subject) =
- self.apply_equal2_constraint(ctx, upcast_id, upcast_id, 1, subject_id, 1)?;
+ fn progress_inference_rule_slicing_expr(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> {
+ let node = &self.infer_nodes[node_index];
+ let rule = node.inference_rule.as_slicing_expr();
+ let subject_index = rule.subject_index;
+ let from_index_index = rule.from_index;
+ let to_index_index = rule.to_index;
+
debug_log!("Rule slicing [node: {}, expr: {}]", node_index, node.expr_id.index); + + // Subject is arraylike, indices are integerlike + let subject_template_progress = self.apply_template_constraint(ctx, subject_index, &ARRAYLIKE_TEMPLATE)?; + let from_template_progress = self.apply_template_constraint(ctx, from_index_index, &INTEGERLIKE_TEMPLATE)?; + let to_template_progress = self.apply_template_constraint(ctx, to_index_index, &INTEGERLIKE_TEMPLATE)?; + let (from_index_progress, to_index_progress) = + self.apply_equal2_constraint(ctx, node_index, from_index_index, 0, to_index_index, 0)?; + + // Same as array indexing: result depends on whether subject is string + // or array + let (is_string, is_not_string) = self.type_is_certainly_or_certainly_not_string(node_index); + let (node_progress, subject_progress) = if is_string { + // Certainly a string + ( + self.apply_forced_constraint(ctx, node_index, &STRING_TEMPLATE)?, + false + ) + } else if is_not_string { + // Certainly not a string, apply template constraint. Then make sure + // that if we have an `Array`, that the slice produces `Slice` + let node_template_progress = self.apply_template_constraint(ctx, node_index, &SLICE_TEMPLATE)?; + let (node_progress, subject_progress) = + self.apply_equal2_constraint(ctx, node_index, node_index, 1, subject_index, 1)?; - (progress_expr_base || progress_expr, progress_subject) - }, - _ => { - // Could be anything, at least attempt to progress subtype - let progress_expr_base = self.apply_template_constraint(ctx, upcast_id, &ARRAYLIKE_TEMPLATE)?; - let (progress_expr, progress_subject) = - self.apply_equal2_constraint(ctx, upcast_id, upcast_id, 1, subject_id, 1)?; + ( + node_template_progress || node_progress, + subject_progress + ) + } else { + // Not sure yet + let node_template_progress = self.apply_template_constraint(ctx, node_index, &ARRAYLIKE_TEMPLATE)?; + let (node_progress, subject_progress) = + self.apply_equal2_constraint(ctx, node_index, node_index, 1, subject_index, 1)?; - (progress_expr_base || progress_expr, progress_subject) - } + ( + node_template_progress || node_progress, + subject_progress + ) }; - debug_log!(" * After:"); - debug_log!(" - Subject type [{}]: {}", progress_subject_base || progress_subject, self.debug_get_display_name(ctx, subject_id)); - debug_log!(" - FromIdx type [{}]: {}", progress_idx_base || progress_from, self.debug_get_display_name(ctx, from_id)); - debug_log!(" - ToIdx type [{}]: {}", progress_idx_base || progress_to, self.debug_get_display_name(ctx, to_id)); - debug_log!(" - Expr type [{}]: {}", progress_expr, self.debug_get_display_name(ctx, upcast_id)); - - if progress_expr { self.queue_expr_parent(ctx, upcast_id); } - if progress_subject_base || progress_subject { self.queue_expr(ctx, subject_id); } - if progress_idx_base || progress_from { self.queue_expr(ctx, from_id); } - if progress_idx_base || progress_to { self.queue_expr(ctx, to_id); } + if node_progress { self.queue_node_parent(node_index); } + if subject_template_progress || subject_progress { self.queue_node(subject_index); } + if from_template_progress || from_index_progress { self.queue_node(from_index_index); } + if to_template_progress || to_index_progress { self.queue_node(to_index_index); } - Ok(()) + return Ok(()); } - fn progress_select_expr(&mut self, ctx: &mut Ctx, id: SelectExpressionId) -> Result<(), ParseError> { - let upcast_id = id.upcast(); - - debug_log!("Select expr: {}", upcast_id.index); - debug_log!(" * Before:"); - debug_log!(" - Subject type: {}", 
self.debug_get_display_name(ctx, ctx.heap[id].subject)); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); - - let subject_id = ctx.heap[id].subject; - let subject_expr_idx = ctx.heap[subject_id].get_unique_id_in_definition(); - let select_expr = &ctx.heap[id]; - let expr_idx = select_expr.unique_id_in_definition; - - let infer_expr = &self.expr_types[expr_idx as usize]; - let extra_idx = infer_expr.extra_data_idx; + fn progress_inference_rule_select_struct_field(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_select_struct_field(); - fn try_get_definition_id_from_inference_type<'a>(types: &'a TypeTable, infer_type: &InferenceType) -> Result, ()> { - for part in &infer_type.parts { - if part.is_marker() || !part.is_concrete() { - continue; - } + let subject_index = rule.subject_index; + let selected_field = rule.selected_field.clone(); - // Part is concrete, check if it is an instance of something - if let InferenceTypePart::Instance(definition_id, _num_sub) = part { - // Lookup type definition and ensure the specified field - // name exists on the struct - let definition = types.get_base_definition(definition_id); - debug_assert!(definition.is_some()); - let definition = definition.unwrap(); + fn get_definition_id_from_inference_type(inference_type: &InferenceType) -> Result, ()> { + for part in inference_type.parts.iter() { + if part.is_marker() { continue; } + if !part.is_concrete() { break; } - return Ok(Some(definition)) + if let InferenceTypePart::Instance(definition_id, _) = part { + return Ok(Some(*definition_id)); } else { - // Expected an instance of something return Err(()) } } - // Nothing is concrete yet - Ok(None) + // Nothing is known yet + return Ok(None); } - fn try_get_tuple_size_from_inference_type(infer_type: &InferenceType) -> Result, ()> { - for part in &infer_type.parts { - if part.is_marker() || !part.is_concrete() { - continue; - } - - if let InferenceTypePart::Tuple(size) = part { - return Ok(Some(*size)); - } else { - // Expected a tuple - return Err(()); - } - } - - // Type is not "defined enough" yet - Ok(None) - } + if node.field_index < 0 { + // Don't know the subject definition, hence the field yet. Try to + // determine it. + let subject_node = &self.infer_nodes[subject_index]; + match get_definition_id_from_inference_type(&subject_node.expr_type) { + Ok(Some(definition_id)) => { + // Determined definition of subject for the first time. 
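+ // Verify that this definition is actually a struct and that the
+ // selected field exists on it, then set up the polymorphic data
+ // that drives the remaining inference for this expression.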
+ let base_definition = ctx.types.get_base_definition(&definition_id).unwrap(); + let struct_definition = if let DefinedTypeVariant::Struct(struct_definition) = &base_definition.definition { + struct_definition + } else { + return Err(ParseError::new_error_at_span( + &ctx.module().source, selected_field.span, format!( + "Can only apply field access to structs, got a subject of type '{}'", + subject_node.expr_type.display_name(&ctx.heap) + ) + )); + }; - let (progress_subject, progress_expr) = match &select_expr.kind { - SelectKind::StructField(field_name) => { - // Handle select of a struct's field - if infer_expr.field_or_monomorph_idx < 0 { - // We don't know the field or the definition it is pointing to yet - // Not yet known, check if we can determine it - let subject_type = &self.expr_types[subject_expr_idx as usize].expr_type; - let type_def = try_get_definition_id_from_inference_type(&ctx.types, subject_type); - - match type_def { - Ok(Some(type_def)) => { - // Subject type is known, check if it is a - // struct and the field exists on the struct - let struct_def = if let DefinedTypeVariant::Struct(struct_def) = &type_def.definition { - struct_def - } else { - return Err(ParseError::new_error_at_span( - &ctx.module().source, field_name.span, format!( - "Can only apply field access to structs, got a subject of type '{}'", - subject_type.display_name(&ctx.heap) - ) - )); - }; - - let mut struct_def_id = None; - - for (field_def_idx, field_def) in struct_def.fields.iter().enumerate() { - if field_def.identifier == *field_name { - // Set field definition and index - let infer_expr = &mut self.expr_types[expr_idx as usize]; - infer_expr.field_or_monomorph_idx = field_def_idx as i32; - struct_def_id = Some(type_def.ast_definition); - break; - } - } - - if struct_def_id.is_none() { - let ast_struct_def = ctx.heap[type_def.ast_definition].as_struct(); - return Err(ParseError::new_error_at_span( - &ctx.module().source, field_name.span, format!( - "this field does not exist on the struct '{}'", - ast_struct_def.identifier.value.as_str() - ) - )) - } - - // Encountered definition and field index for the - // first time - self.insert_initial_select_polymorph_data(ctx, id, struct_def_id.unwrap()); - }, - Ok(None) => { - // Type of subject is not yet known, so we - // cannot make any progress yet - return Ok(()) - }, - Err(()) => { - return Err(ParseError::new_error_at_span( - &ctx.module().source, field_name.span, format!( - "Can only apply field access to structs, got a subject of type '{}'", - subject_type.display_name(&ctx.heap) - ) - )); + // Seek the field that is referenced by the select + // expression + let mut field_found = false; + for (field_index, field) in struct_definition.fields.iter().enumerate() { + if field.identifier.value == selected_field.value { + // Found the field of interest + field_found = true; + let node = &mut self.infer_nodes[node_index]; + node.field_index = field_index as i32; + break; } } - } - - // If here then field index is known, and the referenced struct type - // information is inserted into `extra_data`. Check to see if we can - // do some mutual inference. 
- let poly_data = &mut self.extra_data[extra_idx as usize]; - let mut poly_progress = HashSet::new(); // TODO: @Performance - // Apply to struct's type - let signature_type: *mut _ = &mut poly_data.embedded[0]; - let subject_type: *mut _ = &mut self.expr_types[subject_expr_idx as usize].expr_type; + if !field_found { + let struct_definition = ctx.heap[definition_id].as_struct(); + return Err(ParseError::new_error_at_span( + &ctx.module().source, selected_field.span, format!( + "this field does not exist on the struct '{}'", + struct_definition.identifier.value.as_str() + ) + )); + } - let (_, progress_subject) = Self::apply_equal2_signature_constraint( - ctx, upcast_id, Some(subject_id), poly_data, &mut poly_progress, - signature_type, 0, subject_type, 0 - )?; + // Insert the initial data needed to infer polymorphic + // fields + let extra_index = self.insert_initial_select_polymorph_data(ctx, node_index, definition_id); + let node = &mut self.infer_nodes[node_index]; + node.poly_data_index = extra_index; + }, + Ok(None) => { + // We don't know what to do yet, because we don't know the + // subject type yet. + return Ok(()) + }, + Err(()) => { + return Err(ParseError::new_error_at_span( + &ctx.module().source, rule.selected_field.span, format!( + "Can only apply field access to structs, got a subject of type '{}'", + subject_node.expr_type.display_name(&ctx.heap) + ) + )); + }, + } + } - if progress_subject { - self.expr_queued.push_back(subject_expr_idx); - } + // If here then the field index is known, hence we can start inferring + // the type of the selected field + let field_expr_id = self.infer_nodes[node_index].expr_id; + let subject_expr_id = self.infer_nodes[subject_index].expr_id; + let mut poly_progress_section = self.poly_progress_buffer.start_section(); - // Apply to field's type - let signature_type: *mut _ = &mut poly_data.returned; - let expr_type: *mut _ = &mut self.expr_types[expr_idx as usize].expr_type; + let (_, progress_subject_1) = self.apply_polydata_equal2_constraint( + ctx, node_index, subject_expr_id, "selected struct's", + PolyDataTypeIndex::Associated(0), 0, subject_index, 0, &mut poly_progress_section + )?; + let (_, progress_field_1) = self.apply_polydata_equal2_constraint( + ctx, node_index, field_expr_id, "selected field's", + PolyDataTypeIndex::Returned, 0, node_index, 0, &mut poly_progress_section + )?; - let (_, progress_expr) = Self::apply_equal2_signature_constraint( - ctx, upcast_id, None, poly_data, &mut poly_progress, - signature_type, 0, expr_type, 0 - )?; + // Maybe make progress on types due to inferred polymorphic variables + let progress_subject_2 = self.apply_polydata_polyvar_constraint( + ctx, node_index, PolyDataTypeIndex::Associated(0), subject_index, &poly_progress_section + ); + let progress_field_2 = self.apply_polydata_polyvar_constraint( + ctx, node_index, PolyDataTypeIndex::Returned, node_index, &poly_progress_section + ); - if progress_expr { - if let Some(parent_id) = ctx.heap[upcast_id].parent_expr_id() { - let parent_idx = ctx.heap[parent_id].get_unique_id_in_definition(); - self.expr_queued.push_back(parent_idx); - } - } + if progress_subject_1 || progress_subject_2 { self.queue_node(subject_index); } + if progress_field_1 || progress_field_2 { self.queue_node_parent(node_index); } - // Reapply progress in polymorphic variables to struct's type - let signature_type: *mut _ = &mut poly_data.embedded[0]; - let subject_type: *mut _ = &mut self.expr_types[subject_expr_idx as usize].expr_type; + poly_progress_section.forget(); + 
self.finish_polydata_constraint(node_index); + return Ok(()) + } - let progress_subject = Self::apply_equal2_polyvar_constraint( - poly_data, &poly_progress, signature_type, subject_type - ); + fn progress_inference_rule_select_tuple_member(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_select_tuple_member(); + let subject_index = rule.subject_index; + let tuple_member_index = rule.selected_index; + + if node.field_index < 0 { + let subject_type = &self.infer_nodes[subject_index].expr_type; + let tuple_size = get_tuple_size_from_inference_type(subject_type); + let tuple_size = match tuple_size { + Ok(Some(tuple_size)) => { + tuple_size + }, + Ok(None) => { + // We can't infer anything yet + return Ok(()) + }, + Err(()) => { + let select_expr_span = ctx.heap[node.expr_id].full_span(); + return Err(ParseError::new_error_at_span( + &ctx.module().source, select_expr_span, format!( + "tuple element select cannot be applied to a subject of type '{}'", + subject_type.display_name(&ctx.heap) + ) + )); + } + }; - let signature_type: *mut _ = &mut poly_data.returned; - let expr_type: *mut _ = &mut self.expr_types[expr_idx as usize].expr_type; + // If here then we at least have the tuple size. Now check if the + // index doesn't exceed that size. + if tuple_member_index >= tuple_size as u64 { + let select_expr_span = ctx.heap[node.expr_id].full_span(); + return Err(ParseError::new_error_at_span( + &ctx.module().source, select_expr_span, format!( + "element index {} is out of bounds, tuple has {} elements", + tuple_member_index, tuple_size + ) + )); + } - let progress_expr = Self::apply_equal2_polyvar_constraint( - poly_data, &poly_progress, signature_type, expr_type - ); + // Within bounds, set index on the type inference node + let node = &mut self.infer_nodes[node_index]; + node.field_index = tuple_member_index as i32; + } - (progress_subject, progress_expr) - }, - SelectKind::TupleMember(member_index) => { - let member_index = *member_index; - - if infer_expr.field_or_monomorph_idx < 0 { - // We don't know what kind of tuple we're accessing yet - let subject_type = &self.expr_types[subject_expr_idx as usize].expr_type; - let tuple_size = try_get_tuple_size_from_inference_type(subject_type); - - match tuple_size { - Ok(Some(enum_size)) => { - // Make sure we don't access an element outside of - // the tuple's bounds - if member_index >= enum_size as u64 { - return Err(ParseError::new_error_at_span( - &ctx.module().source, select_expr.full_span, format!( - "element index {} is out of bounds, tuple has {} elements", - member_index, enum_size - ) - )); - } - - // Within bounds, so set the index (such that we - // will not perform this lookup again) - let infer_expr = &mut self.expr_types[expr_idx as usize]; - infer_expr.field_or_monomorph_idx = member_index as i32; - }, - Ok(None) => { - // Nothing is known about the tuple yet - return Ok(()); - }, - Err(()) => { - return Err(ParseError::new_error_at_span( - &ctx.module().source, select_expr.full_span, format!( - "Can only apply tuple element selection to tuples, got a subject of type '{}'", - subject_type.display_name(&ctx.heap) - ) - )); - } - } - } + // If here then we know we can use `tuple_member_index`. 
We need to keep + // computing the offset to the subtype, as its value changes during + // inference + let subject_type = &self.infer_nodes[subject_index].expr_type; + let mut selected_member_start_index = 1; // start just after the InferenceTypeElement::Tuple + for _ in 0..tuple_member_index { + selected_member_start_index = InferenceType::find_subtree_end_idx(&subject_type.parts, selected_member_start_index); + } - // If here then we know which member we're accessing. So seek - // that member in the subject type and apply inference. - let subject_type = &self.expr_types[subject_expr_idx as usize].expr_type; - let mut member_start_idx = 1; - for _ in 0..member_index { - member_start_idx = InferenceType::find_subtree_end_idx(&subject_type.parts, member_start_idx); - } + let (progress_member, progress_subject) = self.apply_equal2_constraint( + ctx, node_index, node_index, 0, subject_index, selected_member_start_index + )?; - let (progress_expr, progress_subject) = self.apply_equal2_constraint( - ctx, upcast_id, upcast_id, 0, subject_id, member_start_idx - )?; + if progress_member { self.queue_node_parent(node_index); } + if progress_subject { self.queue_node(subject_index); } - (progress_subject, progress_expr) - }, - }; + return Ok(()); + } - if progress_subject { self.queue_expr(ctx, subject_id); } - if progress_expr { self.queue_expr_parent(ctx, upcast_id); } + fn progress_inference_rule_literal_struct(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let node_expr_id = node.expr_id; + let rule = node.inference_rule.as_literal_struct(); + + // For each of the fields in the literal struct, apply the type equality + // constraint. If the literal is polymorphic, then we try to progress + // their types during this process + let element_indices_section = self.index_buffer.start_section_initialized(&rule.element_indices); + let mut poly_progress_section = self.poly_progress_buffer.start_section(); + for (field_index, field_node_index) in element_indices_section.iter_copied().enumerate() { + let field_expr_id = self.infer_nodes[field_node_index].expr_id; + let (_, progress_field) = self.apply_polydata_equal2_constraint( + ctx, node_index, field_expr_id, "struct field's", + PolyDataTypeIndex::Associated(field_index), 0, + field_node_index, 0, &mut poly_progress_section + )?; - debug_log!(" * After:"); - debug_log!(" - Subject type [{}]: {}", progress_subject, self.debug_get_display_name(ctx, subject_id)); - debug_log!(" - Expr type [{}]: {}", progress_expr, self.debug_get_display_name(ctx, upcast_id)); + if progress_field { self.queue_node(field_node_index); } + } - Ok(()) - } + // Now we do the same thing for the struct literal expression (the type + // of the struct itself). + let (_, progress_literal_1) = self.apply_polydata_equal2_constraint( + ctx, node_index, node_expr_id, "struct literal's", + PolyDataTypeIndex::Returned, 0, node_index, 0, &mut poly_progress_section + )?; - fn progress_literal_expr(&mut self, ctx: &mut Ctx, id: LiteralExpressionId) -> Result<(), ParseError> { - let upcast_id = id.upcast(); - let expr = &ctx.heap[id]; - let expr_idx = expr.unique_id_in_definition; - let extra_idx = self.expr_types[expr_idx as usize].extra_data_idx; + // And the other way around: if any of our polymorphic variables are + // more specific then they were before, then we forward that information + // back to our struct/fields. 
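+ // For illustration (hypothetical type, not from this code base): given a
+ // struct `Pair<T>` whose two fields both have type `T`, a literal
+ // `Pair{ first: 5, second: x }` lets the first field pin `T` to an
+ // integer-like type; forwarding the progressed polymorphic variable then
+ // also progresses the type of `second` (i.e. of `x`) and of the literal itself.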
+ for (field_index, field_node_index) in element_indices_section.iter_copied().enumerate() { + let progress_field = self.apply_polydata_polyvar_constraint( + ctx, node_index, PolyDataTypeIndex::Associated(field_index), + field_node_index, &poly_progress_section + ); - debug_log!("Literal expr: {}", upcast_id.index); - debug_log!(" * Before:"); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); + if progress_field { self.queue_node(field_node_index); } + } - let progress_expr = match &expr.value { - Literal::Null => { - self.apply_template_constraint(ctx, upcast_id, &MESSAGE_TEMPLATE)? - }, - Literal::Integer(_) => { - self.apply_template_constraint(ctx, upcast_id, &INTEGERLIKE_TEMPLATE)? - }, - Literal::True | Literal::False => { - self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)? - }, - Literal::Character(_) => { - self.apply_forced_constraint(ctx, upcast_id, &CHARACTER_TEMPLATE)? - }, - Literal::String(_) => { - self.apply_forced_constraint(ctx, upcast_id, &STRING_TEMPLATE)? - }, - Literal::Struct(data) => { - let extra = &mut self.extra_data[extra_idx as usize]; - for _poly in &extra.poly_vars { - debug_log!(" * Poly: {}", _poly.display_name(&ctx.heap)); - } - let mut poly_progress = HashSet::new(); - debug_assert_eq!(extra.embedded.len(), data.fields.len()); - - debug_log!(" * During (inferring types from fields and struct type):"); - - // Mutually infer field signature/expression types - for (field_idx, field) in data.fields.iter().enumerate() { - let field_expr_id = field.value; - let field_expr_idx = ctx.heap[field_expr_id].get_unique_id_in_definition(); - let signature_type: *mut _ = &mut extra.embedded[field_idx]; - let field_type: *mut _ = &mut self.expr_types[field_expr_idx as usize].expr_type; - let (_, progress_arg) = Self::apply_equal2_signature_constraint( - ctx, upcast_id, Some(field_expr_id), extra, &mut poly_progress, - signature_type, 0, field_type, 0 - )?; - - debug_log!( - " - Field {} type | sig: {}, field: {}", field_idx, - unsafe{&*signature_type}.display_name(&ctx.heap), - unsafe{&*field_type}.display_name(&ctx.heap) - ); + let progress_literal_2 = self.apply_polydata_polyvar_constraint( + ctx, node_index, PolyDataTypeIndex::Returned, + node_index, &poly_progress_section + ); - if progress_arg { - self.expr_queued.push_back(field_expr_idx); - } - } + if progress_literal_1 || progress_literal_2 { self.queue_node_parent(node_index); } - debug_log!(" - Field poly progress | {:?}", poly_progress); + poly_progress_section.forget(); + element_indices_section.forget(); - // Same for the type of the struct itself - let signature_type: *mut _ = &mut extra.returned; - let expr_type: *mut _ = &mut self.expr_types[expr_idx as usize].expr_type; - let (_, progress_expr) = Self::apply_equal2_signature_constraint( - ctx, upcast_id, None, extra, &mut poly_progress, - signature_type, 0, expr_type, 0 - )?; + self.finish_polydata_constraint(node_index); + return Ok(()) + } - debug_log!( - " - Ret type | sig: {}, expr: {}", - unsafe{&*signature_type}.display_name(&ctx.heap), - unsafe{&*expr_type}.display_name(&ctx.heap) - ); - debug_log!(" - Ret poly progress | {:?}", poly_progress); + fn progress_inference_rule_literal_enum(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let node_expr_id = node.expr_id; + let mut poly_progress_section = self.poly_progress_buffer.start_section(); - if progress_expr { - // TODO: @cleanup, cannot call utility self.queue_parent thingo - if 
let Some(parent_id) = ctx.heap[upcast_id].parent_expr_id() { - let parent_idx = ctx.heap[parent_id].get_unique_id_in_definition(); - self.expr_queued.push_back(parent_idx); - } - } + // An enum literal type is simply, well, the enum's type. However, it + // might still have polymorphic variables, hence the use of `PolyData`. + let (_, progress_literal_1) = self.apply_polydata_equal2_constraint( + ctx, node_index, node_expr_id, "enum literal's", + PolyDataTypeIndex::Returned, 0, node_index, 0, &mut poly_progress_section + )?; - // Check which expressions use the polymorphic arguments. If the - // polymorphic variables have been progressed then we try to - // progress them inside the expression as well. - debug_log!(" * During (reinferring from progressed polyvars):"); - - // For all field expressions - for field_idx in 0..extra.embedded.len() { - // Note: fields in extra.embedded are in the same order as - // they are specified in the literal. Whereas - // `data.fields[...].field_idx` points to the field in the - // struct definition. - let signature_type: *mut _ = &mut extra.embedded[field_idx]; - let field_expr_id = data.fields[field_idx].value; - let field_expr_idx = ctx.heap[field_expr_id].get_unique_id_in_definition(); - let field_type: *mut _ = &mut self.expr_types[field_expr_idx as usize].expr_type; - - let progress_arg = Self::apply_equal2_polyvar_constraint( - extra, &poly_progress, signature_type, field_type - ); + let progress_literal_2 = self.apply_polydata_polyvar_constraint( + ctx, node_index, PolyDataTypeIndex::Returned, node_index, &poly_progress_section + ); - debug_log!( - " - Field {} type | sig: {}, field: {}", field_idx, - unsafe{&*signature_type}.display_name(&ctx.heap), - unsafe{&*field_type}.display_name(&ctx.heap) - ); - if progress_arg { - self.expr_queued.push_back(field_expr_idx); - } - } - - // For the return type - let signature_type: *mut _ = &mut extra.returned; - let expr_type: *mut _ = &mut self.expr_types[expr_idx as usize].expr_type; + if progress_literal_1 || progress_literal_2 { self.queue_node_parent(node_index); } - let progress_expr = Self::apply_equal2_polyvar_constraint( - extra, &poly_progress, signature_type, expr_type - ); + poly_progress_section.forget(); + self.finish_polydata_constraint(node_index); + return Ok(()); + } - progress_expr - }, - Literal::Enum(_) => { - let extra = &mut self.extra_data[extra_idx as usize]; - for _poly in &extra.poly_vars { - debug_log!(" * Poly: {}", _poly.display_name(&ctx.heap)); - } - let mut poly_progress = HashSet::new(); - - debug_log!(" * During (inferring types from return type)"); - - let signature_type: *mut _ = &mut extra.returned; - let expr_type: *mut _ = &mut self.expr_types[expr_idx as usize].expr_type; - let (_, progress_expr) = Self::apply_equal2_signature_constraint( - ctx, upcast_id, None, extra, &mut poly_progress, - signature_type, 0, expr_type, 0 - )?; + fn progress_inference_rule_literal_union(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let node_expr_id = node.expr_id; + let rule = node.inference_rule.as_literal_union(); + + // Infer type of any embedded values in the union variant. At the same + // time progress the polymorphic variables associated with the union. 
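+ // This mirrors the struct-literal case above: the embedded values play the
+ // role of the struct fields (the `Associated(i)` poly-data entries), and the
+ // union type itself is the `Returned` entry.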
+ let element_indices_section = self.index_buffer.start_section_initialized(&rule.element_indices); + let mut poly_progress_section = self.poly_progress_buffer.start_section(); + + for (embedded_index, embedded_node_index) in element_indices_section.iter_copied().enumerate() { + let embedded_node_expr_id = self.infer_nodes[embedded_node_index].expr_id; + let (_, progress_embedded) = self.apply_polydata_equal2_constraint( + ctx, node_index, embedded_node_expr_id, "embedded value's", + PolyDataTypeIndex::Associated(embedded_index), 0, + embedded_node_index, 0, &mut poly_progress_section + )?; - debug_log!( - " - Ret type | sig: {}, expr: {}", - unsafe{&*signature_type}.display_name(&ctx.heap), - unsafe{&*expr_type}.display_name(&ctx.heap) - ); + if progress_embedded { self.queue_node(embedded_node_index); } + } - if progress_expr { - // TODO: @cleanup - if let Some(parent_id) = ctx.heap[upcast_id].parent_expr_id() { - let parent_idx = ctx.heap[parent_id].get_unique_id_in_definition(); - self.expr_queued.push_back(parent_idx); - } - } + let (_, progress_literal_1) = self.apply_polydata_equal2_constraint( + ctx, node_index, node_expr_id, "union's", + PolyDataTypeIndex::Returned, 0, node_index, 0, &mut poly_progress_section + )?; - debug_log!(" * During (reinferring from progress polyvars):"); - let progress_expr = Self::apply_equal2_polyvar_constraint( - extra, &poly_progress, signature_type, expr_type - ); + // Propagate progress in the polymorphic variables to the expressions + // that constitute the union literal. + for (embedded_index, embedded_node_index) in element_indices_section.iter_copied().enumerate() { + let progress_embedded = self.apply_polydata_polyvar_constraint( + ctx, node_index, PolyDataTypeIndex::Associated(embedded_index), + embedded_node_index, &poly_progress_section + ); - progress_expr - }, - Literal::Union(data) => { - let extra = &mut self.extra_data[extra_idx as usize]; - for _poly in &extra.poly_vars { - debug_log!(" * Poly: {}", _poly.display_name(&ctx.heap)); - } - let mut poly_progress = HashSet::new(); - debug_assert_eq!(extra.embedded.len(), data.values.len()); - - debug_log!(" * During (inferring types from variant values and union type):"); - - // Mutually infer union variant values - for (value_idx, value_expr_id) in data.values.iter().enumerate() { - let value_expr_id = *value_expr_id; - let value_expr_idx = ctx.heap[value_expr_id].get_unique_id_in_definition(); - let signature_type: *mut _ = &mut extra.embedded[value_idx]; - let value_type: *mut _ = &mut self.expr_types[value_expr_idx as usize].expr_type; - let (_, progress_arg) = Self::apply_equal2_signature_constraint( - ctx, upcast_id, Some(value_expr_id), extra, &mut poly_progress, - signature_type, 0, value_type, 0 - )?; - - debug_log!( - " - Value {} type | sig: {}, field: {}", value_idx, - unsafe{&*signature_type}.display_name(&ctx.heap), - unsafe{&*value_type}.display_name(&ctx.heap) - ); + if progress_embedded { self.queue_node(embedded_node_index); } + } - if progress_arg { - self.expr_queued.push_back(value_expr_idx); - } - } + let progress_literal_2 = self.apply_polydata_polyvar_constraint( + ctx, node_index, PolyDataTypeIndex::Returned, node_index, &poly_progress_section + ); - debug_log!(" - Field poly progress | {:?}", poly_progress); + if progress_literal_1 || progress_literal_2 { self.queue_node_parent(node_index); } - // Infer type of union itself - let signature_type: *mut _ = &mut extra.returned; - let expr_type: *mut _ = &mut self.expr_types[expr_idx as usize].expr_type; - let (_, 
progress_expr) = Self::apply_equal2_signature_constraint( - ctx, upcast_id, None, extra, &mut poly_progress, - signature_type, 0, expr_type, 0 - )?; + poly_progress_section.forget(); + self.finish_polydata_constraint(node_index); + return Ok(()); + } - debug_log!( - " - Ret type | sig: {}, expr: {}", - unsafe{&*signature_type}.display_name(&ctx.heap), - unsafe{&*expr_type}.display_name(&ctx.heap) - ); - debug_log!(" - Ret poly progress | {:?}", poly_progress); + fn progress_inference_rule_literal_array(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_literal_array(); - if progress_expr { - // TODO: @cleanup, borrowing rules - if let Some(parent_id) = ctx.heap[upcast_id].parent_expr_id() { - let parent_idx = ctx.heap[parent_id].get_unique_id_in_definition(); - self.expr_queued.push_back(parent_idx); - } - } + // Apply equality rule to all of the elements that form the array + let argument_node_indices = self.index_buffer.start_section_initialized(&rule.element_indices); + let mut argument_progress_section = self.bool_buffer.start_section(); + self.apply_equal_n_constraint(ctx, node_index, &argument_node_indices, &mut argument_progress_section)?; - debug_log!(" * During (reinferring from progress polyvars):"); - - // For all embedded values of the union variant - for value_idx in 0..extra.embedded.len() { - let signature_type: *mut _ = &mut extra.embedded[value_idx]; - let value_expr_id = data.values[value_idx]; - let value_expr_idx = ctx.heap[value_expr_id].get_unique_id_in_definition(); - let value_type: *mut _ = &mut self.expr_types[value_expr_idx as usize].expr_type; - - let progress_arg = Self::apply_equal2_polyvar_constraint( - extra, &poly_progress, signature_type, value_type - ); + debug_assert_eq!(argument_node_indices.len(), argument_progress_section.len()); + for argument_index in 0..argument_node_indices.len() { + let argument_node_index = argument_node_indices[argument_index]; + let progress = argument_progress_section[argument_index]; - debug_log!( - " - Value {} type | sig: {}, value: {}", value_idx, - unsafe{&*signature_type}.display_name(&ctx.heap), - unsafe{&*value_type}.display_name(&ctx.heap) - ); - if progress_arg { - self.expr_queued.push_back(value_expr_idx); - } - } + if progress { self.queue_node(argument_node_index); } + } - // And for the union type itself - let signature_type: *mut _ = &mut extra.returned; - let expr_type: *mut _ = &mut self.expr_types[expr_idx as usize].expr_type; + // If elements are of type `T`, then the array is of type `Array`, so: + let mut progress_literal = self.apply_template_constraint(ctx, node_index, &ARRAY_TEMPLATE)?; + if argument_node_indices.len() != 0 { + let argument_node_index = argument_node_indices[0]; + let (progress_literal_inner, progress_argument) = self.apply_equal2_constraint( + ctx, node_index, node_index, 1, argument_node_index, 0 + )?; - let progress_expr = Self::apply_equal2_polyvar_constraint( - extra, &poly_progress, signature_type, expr_type - ); + progress_literal = progress_literal || progress_literal_inner; - progress_expr - }, - Literal::Array(data) => { - let expr_elements = self.expr_buffer.start_section_initialized(data.as_slice()); - debug_log!("Array expr ({} elements): {}", expr_elements.len(), upcast_id.index); - debug_log!(" * Before:"); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); - - // All elements should have an equal type - let mut bool_buffer = 
self.bool_buffer.start_section(); - self.apply_equal_n_constraint(ctx, upcast_id, &expr_elements, &mut bool_buffer)?; - for (progress_arg, arg_id) in bool_buffer.iter_copied().zip(expr_elements.iter_copied()) { - if progress_arg { - self.queue_expr(ctx, arg_id); - } - } + // It is possible that the `Array` has a more progress `T` then + // the arguments. So in the case we progress our argument type we + // simply queue this rule again + if progress_argument { self.queue_node(node_index); } + } - // And the output should be an array of the element types - let mut progress_expr = self.apply_template_constraint(ctx, upcast_id, &ARRAY_TEMPLATE)?; - if expr_elements.len() != 0 { - let first_arg_id = expr_elements[0]; - let (inner_expr_progress, arg_progress) = self.apply_equal2_constraint( - ctx, upcast_id, upcast_id, 1, first_arg_id, 0 - )?; + argument_node_indices.forget(); + argument_progress_section.forget(); - progress_expr = progress_expr || inner_expr_progress; + if progress_literal { self.queue_node_parent(node_index); } + return Ok(()); + } - // Note that if the array type progressed the type of the arguments, - // then we should enqueue this progression function again - // TODO: @fix Make apply_equal_n accept a start idx as well - if arg_progress { self.queue_expr(ctx, upcast_id); } - } + fn progress_inference_rule_literal_tuple(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_literal_tuple(); - debug_log!(" * After:"); - debug_log!(" - Expr type [{}]: {}", progress_expr, self.debug_get_display_name(ctx, upcast_id)); + let element_indices = self.index_buffer.start_section_initialized(&rule.element_indices); - expr_elements.forget(); - progress_expr - }, - Literal::Tuple(data) => { - let expr_elements = self.expr_buffer.start_section_initialized(data.as_slice()); - debug_log!("Tuple expr ({} elements): {}", expr_elements.len(), upcast_id.index); - debug_log!(" * Before:"); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); - - // Initial tuple constraint - let num_members = expr_elements.len(); - let mut initial_type = Vec::with_capacity(num_members + 1); // TODO: @performance - initial_type.push(InferenceTypePart::Tuple(num_members as u32)); - for _ in 0..num_members { - initial_type.push(InferenceTypePart::Unknown); - } - let mut progress_expr = self.apply_template_constraint(ctx, upcast_id, &initial_type)?; - - // The elements of the tuple can have any type, but they must - // end up as arguments to the output tuple type. - debug_log!(" * During (checking expressions constituting tuple):"); - for (member_expr_index, member_expr_id) in expr_elements.iter_copied().enumerate() { - // For the current expression index, (re)compute the - // position in the tuple type where the types should match. - let mut start_index = 1; // first element is Tuple type, second is the first child - for _ in 0..member_expr_index { - let tuple_expr_index = ctx.heap[id].unique_id_in_definition; - let tuple_type = &self.expr_types[tuple_expr_index as usize].expr_type; - start_index = InferenceType::find_subtree_end_idx(&tuple_type.parts, start_index); - debug_assert_ne!(start_index, tuple_type.parts.len()); // would imply less tuple type children than member expressions - } + // Check if we need to apply the initial tuple template type. Note that + // this is a hacky check. 
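+ // The template is a `Tuple(n)` part followed by `n` `Unknown` parts: applying
+ // it forces this expression to be an n-tuple, and the constraint application
+ // returns an early error if the already-inferred type cannot be one.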
+ let num_tuple_elements = rule.element_indices.len(); + let mut template_type = Vec::with_capacity(num_tuple_elements + 1); // TODO: @performance + template_type.push(InferenceTypePart::Tuple(num_tuple_elements as u32)); + for _ in 0..num_tuple_elements { + template_type.push(InferenceTypePart::Unknown); + } - // Apply the constraint - let (member_progress_expr, member_progress) = self.apply_equal2_constraint( - ctx, upcast_id, upcast_id, start_index, member_expr_id, 0 - )?; - debug_log!(" - Member {} type | {}", member_expr_index, self.debug_get_display_name(ctx, *member_expr_id)); - progress_expr = progress_expr || member_progress_expr; + let mut progress_literal = self.apply_template_constraint(ctx, node_index, &template_type)?; - if member_progress { - self.queue_expr(ctx, member_expr_id); - } - } + // Because of the (early returning error) check above, we're certain + // that the tuple has the correct number of elements. Now match each + // element expression type to the tuple subtype. + let mut element_subtree_start_index = 1; // first element is InferenceTypePart::Tuple + for element_node_index in element_indices.iter_copied() { + let (progress_literal_element, progress_element) = self.apply_equal2_constraint( + ctx, node_index, node_index, element_subtree_start_index, element_node_index, 0 + )?; - expr_elements.forget(); - progress_expr + progress_literal = progress_literal || progress_literal_element; + if progress_element { + self.queue_node(element_node_index); } - }; - debug_log!(" * After:"); - debug_log!(" - Expr type [{}]: {}", progress_expr, self.debug_get_display_name(ctx, upcast_id)); + // Prepare for next element + let node = &self.infer_nodes[node_index]; + let subtree_end_index = InferenceType::find_subtree_end_idx(&node.expr_type.parts, element_subtree_start_index); + element_subtree_start_index = subtree_end_index; + } + debug_assert_eq!(element_subtree_start_index, self.infer_nodes[node_index].expr_type.parts.len()); - if progress_expr { self.queue_expr_parent(ctx, upcast_id); } + if progress_literal { self.queue_node_parent(node_index); } - Ok(()) + element_indices.forget(); + return Ok(()); } - fn progress_cast_expr(&mut self, ctx: &mut Ctx, id: CastExpressionId) -> Result<(), ParseError> { - let upcast_id = id.upcast(); - let expr = &ctx.heap[id]; - let expr_idx = expr.unique_id_in_definition; - - debug_log!("Casting expr: {}", upcast_id.index); - debug_log!(" * Before:"); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); - debug_log!(" - Subject type: {}", self.debug_get_display_name(ctx, expr.subject)); - - // The cast expression might have its output type fixed by the - // programmer, so apply that type to the output. Apart from that casting - // acts like a blocker for two-way inference. So we'll just have to wait - // until we know if the cast is valid. 
- // TODO: Another thing that has to be updated the moment the type - // inferencer is fully index/job-based - let infer_type = self.determine_inference_type_from_parser_type_elements(&expr.to_type.elements, true); - let expr_progress = self.apply_template_constraint(ctx, upcast_id, &infer_type.parts)?; - - if expr_progress { - self.queue_expr_parent(ctx, upcast_id); - } - - // Check if the two types are compatible - debug_log!(" * After:"); - debug_log!(" - Expr type [{}]: {}", expr_progress, self.debug_get_display_name(ctx, upcast_id)); - debug_log!(" - Note that the subject type can never be inferred"); - debug_log!(" * Decision:"); - - let subject_idx = ctx.heap[expr.subject].get_unique_id_in_definition(); - let expr_type = &self.expr_types[expr_idx as usize].expr_type; - let subject_type = &self.expr_types[subject_idx as usize].expr_type; - if !expr_type.is_done || !subject_type.is_done { - // Not yet done - debug_log!(" - Casting is valid: unknown as the types are not yet complete"); - return Ok(()) + fn progress_inference_rule_cast_expr(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_cast_expr(); + let subject_index = rule.subject_index; + let subject = &self.infer_nodes[subject_index]; + + // Make sure that both types are completely done. Note: a cast + // expression cannot really infer anything between the subject and the + // output type, we can only make sure that, at the end, the cast is + // correct. + if !node.expr_type.is_done || !subject.expr_type.is_done { + return Ok(()); } - // Valid casts: (bool, integer, character) can always be cast to one - // another. A cast from a type to itself is also valid. + // Both types are known, currently the only valid casts are bool, + // integer and character casts. 
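+ // A cast from a type to itself is also accepted; that case is handled by the
+ // subtree comparison further down.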
fn is_bool_int_or_char(parts: &[InferenceTypePart]) -> bool { - return parts.len() == 1 && ( - parts[0] == InferenceTypePart::Bool || - parts[0] == InferenceTypePart::Character || - parts[0].is_concrete_integer() - ); + let mut index = 0; + while index < parts.len() { + let part = &parts[index]; + if !part.is_marker() { break; } + index += 1; + } + + debug_assert!(index != parts.len()); + let part = &parts[index]; + if *part == InferenceTypePart::Bool || *part == InferenceTypePart::Character || part.is_concrete_integer() { + debug_assert!(index + 1 == parts.len()); // type is done, first part does not have children -> must be at end + return true; + } else { + return false; + } } - let is_valid = if is_bool_int_or_char(&expr_type.parts) && is_bool_int_or_char(&subject_type.parts) { + let is_valid = if is_bool_int_or_char(&node.expr_type.parts) && is_bool_int_or_char(&subject.expr_type.parts) { true - } else if expr_type.parts == subject_type.parts { + } else if InferenceType::check_subtrees(&node.expr_type.parts, 0, &subject.expr_type.parts, 0) { + // again: check_subtrees is sufficient since both types are done true } else { false }; - debug_log!(" - Casting is valid: {}", is_valid); - if !is_valid { - let cast_expr = &ctx.heap[id]; - let subject_expr = &ctx.heap[cast_expr.subject]; + let cast_expr = &ctx.heap[node.expr_id]; + let subject_expr = &ctx.heap[subject.expr_id]; return Err(ParseError::new_error_str_at_span( - &ctx.module().source, cast_expr.full_span, "invalid casting operation" + &ctx.module().source, cast_expr.full_span(), "invalid casting operation" ).with_info_at_span( &ctx.module().source, subject_expr.full_span(), format!( - "cannot cast the argument type '{}' to the cast type '{}'", - subject_type.display_name(&ctx.heap), - expr_type.display_name(&ctx.heap) + "cannot cast the argument type '{}' to the type '{}'", + subject.expr_type.display_name(&ctx.heap), + node.expr_type.display_name(&ctx.heap) ) )); } - Ok(()) + return Ok(()) } - // TODO: @cleanup, see how this can be cleaned up once I implement - // polymorphic struct/enum/union literals. These likely follow the same - // pattern as here. 
- fn progress_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> Result<(), ParseError> { - let upcast_id = id.upcast(); - let expr = &ctx.heap[id]; - let expr_idx = expr.unique_id_in_definition; - let extra_idx = self.expr_types[expr_idx as usize].extra_data_idx; - - debug_log!("Call expr '{}': {}", ctx.heap[expr.definition].identifier().value.as_str(), upcast_id.index); - debug_log!(" * Before:"); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); - debug_log!(" * During (inferring types from arguments and return type):"); - - let extra = &mut self.extra_data[extra_idx as usize]; - - // Check if we can make progress using the arguments and/or return types - // while keeping track of the polyvars we've extended - let mut poly_progress = HashSet::new(); - debug_assert_eq!(extra.embedded.len(), expr.arguments.len()); - - for (call_arg_idx, arg_id) in expr.arguments.clone().into_iter().enumerate() { - let arg_expr_idx = ctx.heap[arg_id].get_unique_id_in_definition(); - let signature_type: *mut _ = &mut extra.embedded[call_arg_idx]; - let argument_type: *mut _ = &mut self.expr_types[arg_expr_idx as usize].expr_type; - let (_, progress_arg) = Self::apply_equal2_signature_constraint( - ctx, upcast_id, Some(arg_id), extra, &mut poly_progress, - signature_type, 0, argument_type, 0 + fn progress_inference_rule_call_expr(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let node_expr_id = node.expr_id; + let rule = node.inference_rule.as_call_expr(); + + let mut poly_progress_section = self.poly_progress_buffer.start_section(); + let argument_node_indices = self.index_buffer.start_section_initialized(&rule.argument_indices); + + // Perform inference on arguments to function, while trying to figure + // out the polymorphic variables + for (argument_index, argument_node_index) in argument_node_indices.iter_copied().enumerate() { + let argument_expr_id = self.infer_nodes[argument_node_index].expr_id; + let (_, progress_argument) = self.apply_polydata_equal2_constraint( + ctx, node_index, argument_expr_id, "argument's", + PolyDataTypeIndex::Associated(argument_index), 0, + argument_node_index, 0, &mut poly_progress_section )?; - debug_log!( - " - Arg {} type | sig: {}, arg: {}", call_arg_idx, - unsafe{&*signature_type}.display_name(&ctx.heap), - unsafe{&*argument_type}.display_name(&ctx.heap)); - - if progress_arg { - // Progressed argument expression - self.expr_queued.push_back(arg_expr_idx); - } + if progress_argument { self.queue_node(argument_node_index); } } - // Do the same for the return type - let signature_type: *mut _ = &mut extra.returned; - let expr_type: *mut _ = &mut self.expr_types[expr_idx as usize].expr_type; - let (_, progress_expr) = Self::apply_equal2_signature_constraint( - ctx, upcast_id, None, extra, &mut poly_progress, - signature_type, 0, expr_type, 0 + // Same for the return type. 
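+ // The `Returned` poly-data entry holds the signature's return type, just as
+ // the `Associated(i)` entries above hold the parameter types.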
+ let (_, progress_call_1) = self.apply_polydata_equal2_constraint( + ctx, node_index, node_expr_id, "return", + PolyDataTypeIndex::Returned, 0, + node_index, 0, &mut poly_progress_section )?; - debug_log!( - " - Ret type | sig: {}, expr: {}", - unsafe{&*signature_type}.display_name(&ctx.heap), - unsafe{&*expr_type}.display_name(&ctx.heap) - ); - - if progress_expr { - // TODO: @cleanup, cannot call utility self.queue_parent thingo - if let Some(parent_id) = ctx.heap[upcast_id].parent_expr_id() { - let parent_idx = ctx.heap[parent_id].get_unique_id_in_definition(); - self.expr_queued.push_back(parent_idx); - } - } - - // If we did not have an error in the polymorph inference above, then - // reapplying the polymorph type to each argument type and the return - // type should always succeed. - debug_log!(" * During (reinferring from progressed polyvars):"); - for (_poly_idx, _poly_var) in extra.poly_vars.iter().enumerate() { - debug_log!(" - Poly {} | sig: {}", _poly_idx, _poly_var.display_name(&ctx.heap)); - } - // TODO: @performance If the algorithm is changed to be more "on demand - // argument re-evaluation", instead of "all-argument re-evaluation", - // then this is no longer true - for arg_idx in 0..extra.embedded.len() { - let signature_type: *mut _ = &mut extra.embedded[arg_idx]; - let arg_expr_id = expr.arguments[arg_idx]; - let arg_expr_idx = ctx.heap[arg_expr_id].get_unique_id_in_definition(); - let arg_type: *mut _ = &mut self.expr_types[arg_expr_idx as usize].expr_type; - - let progress_arg = Self::apply_equal2_polyvar_constraint( - extra, &poly_progress, - signature_type, arg_type - ); - - debug_log!( - " - Arg {} type | sig: {}, arg: {}", arg_idx, - unsafe{&*signature_type}.display_name(&ctx.heap), - unsafe{&*arg_type}.display_name(&ctx.heap) + // We will now apply any progression in the polymorphic variable type + // back to the arguments. + for (argument_index, argument_node_index) in argument_node_indices.iter_copied().enumerate() { + let progress_argument = self.apply_polydata_polyvar_constraint( + ctx, node_index, PolyDataTypeIndex::Associated(argument_index), + argument_node_index, &poly_progress_section ); - if progress_arg { - self.expr_queued.push_back(arg_expr_idx); - } - } - // Once more for the return type - let signature_type: *mut _ = &mut extra.returned; - let ret_type: *mut _ = &mut self.expr_types[expr_idx as usize].expr_type; + if progress_argument { self.queue_node(argument_node_index); } + } - let progress_ret = Self::apply_equal2_polyvar_constraint( - extra, &poly_progress, signature_type, ret_type + // And back to the return type. 
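+ // For illustration (hypothetical function, not from this code base): for a
+ // call to `T max(T a, T b)` whose result is used where a specific integer
+ // type is required, the return-type unification above pins `T`; the poly-var
+ // passes (the argument loop above and this return-type application) then push
+ // that knowledge back into the argument and result types.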
+ let progress_call_2 = self.apply_polydata_polyvar_constraint( + ctx, node_index, PolyDataTypeIndex::Returned, + node_index, &poly_progress_section ); - debug_log!( - " - Ret type | sig: {}, arg: {}", - unsafe{&*signature_type}.display_name(&ctx.heap), - unsafe{&*ret_type}.display_name(&ctx.heap) - ); - if progress_ret { - self.queue_expr_parent(ctx, upcast_id); - } - debug_log!(" * After:"); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); + if progress_call_1 || progress_call_2 { self.queue_node_parent(node_index); } - Ok(()) + poly_progress_section.forget(); + argument_node_indices.forget(); + + self.finish_polydata_constraint(node_index); + return Ok(()) } - fn progress_variable_expr(&mut self, ctx: &mut Ctx, id: VariableExpressionId) -> Result<(), ParseError> { - let upcast_id = id.upcast(); - let var_expr = &ctx.heap[id]; - let var_expr_idx = var_expr.unique_id_in_definition; - let var_id = var_expr.declaration.unwrap(); + fn progress_inference_rule_variable_expr(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &mut self.infer_nodes[node_index]; + let rule = node.inference_rule.as_variable_expr(); + let var_data_index = rule.var_data_index; - debug_log!("Variable expr '{}': {}", ctx.heap[var_id].identifier.value.as_str(), upcast_id.index); - debug_log!(" * Before:"); - debug_log!(" - Var type: {}", self.var_types.get(&var_id).unwrap().var_type.display_name(&ctx.heap)); - debug_log!(" - Expr type: {}", self.debug_get_display_name(ctx, upcast_id)); + let var_data = &mut self.var_data[var_data_index]; + // Apply inference to the shared variable type and the expression type + let shared_type: *mut _ = &mut var_data.var_type; + let expr_type: *mut _ = &mut node.expr_type; - // Retrieve shared variable type and expression type and apply inference - let var_data = self.var_types.get_mut(&var_id).unwrap(); - let expr_type = &mut self.expr_types[var_expr_idx as usize].expr_type; + let inference_result = unsafe { + // safety: vectors exist in different storage vectors, so cannot alias + InferenceType::infer_subtrees_for_both_types(shared_type, 0, expr_type, 0) + }; - let infer_res = unsafe{ InferenceType::infer_subtrees_for_both_types( - &mut var_data.var_type as *mut _, 0, expr_type, 0 - ) }; - if infer_res == DualInferenceResult::Incompatible { - let var_decl = &ctx.heap[var_id]; - return Err(ParseError::new_error_at_span( - &ctx.module().source, var_decl.identifier.span, format!( - "Conflicting types for this variable, previously assigned the type '{}'", - var_data.var_type.display_name(&ctx.heap) - ) - ).with_info_at_span( - &ctx.module().source, var_expr.identifier.span, format!( - "But inferred to have incompatible type '{}' here", - expr_type.display_name(&ctx.heap) - ) - )) + if inference_result == DualInferenceResult::Incompatible { + return Err(self.construct_variable_type_error(ctx, node_index)); } - let progress_var = infer_res.modified_lhs(); - let progress_expr = infer_res.modified_rhs(); + let progress_var_data = inference_result.modified_lhs(); + let progress_expr = inference_result.modified_rhs(); - if progress_var { - // Let other variable expressions using this type progress as well - for other_expr in var_data.used_at.iter() { - if *other_expr != upcast_id { - let other_expr_idx = ctx.heap[*other_expr].get_unique_id_in_definition(); - self.expr_queued.push_back(other_expr_idx); + if progress_var_data { + // We progressed the type of the shared variable, so propagate this + // to all associated variable 
expressions (and relatived variables). + for other_node_index in var_data.used_at.iter().copied() { + if other_node_index != node_index { + self.node_queued.push_back(other_node_index); } } - // Let a linked port know that our type has updated - if let Some(linked_id) = var_data.linked_var { - // Only perform one-way inference to prevent updating our type, - // this would lead to an inconsistency in the type inference - // algorithm otherwise. - let var_type: *mut _ = &mut var_data.var_type; - let link_data = self.var_types.get_mut(&linked_id).unwrap(); - + if let Some(linked_var_data_index) = var_data.linked_var { + // Only perform one-way inference, progressing the linked + // variable. + // note: because this "linking" is used only for channels, we + // will start inference one level below the top-level in the + // type tree (i.e. ensure `T` in `in` and `out` is equal). debug_assert!( - unsafe{&*var_type}.parts[0] == InferenceTypePart::Input || - unsafe{&*var_type}.parts[0] == InferenceTypePart::Output + var_data.var_type.parts[0] == InferenceTypePart::Input || + var_data.var_type.parts[0] == InferenceTypePart::Output ); + let this_var_type: *const _ = &var_data.var_type; + let linked_var_data = &mut self.var_data[linked_var_data_index]; debug_assert!( - link_data.var_type.parts[0] == InferenceTypePart::Input || - link_data.var_type.parts[0] == InferenceTypePart::Output + linked_var_data.var_type.parts[0] == InferenceTypePart::Input || + linked_var_data.var_type.parts[0] == InferenceTypePart::Output + ); + + // safety: by construction var_data_index and linked_var_data_index cannot be the + // same, hence we're not aliasing here. + let inference_result = InferenceType::infer_subtree_for_single_type( + &mut linked_var_data.var_type, 1, + unsafe{ &(*this_var_type).parts }, 1, false ); - match InferenceType::infer_subtree_for_single_type(&mut link_data.var_type, 1, &unsafe{&*var_type}.parts, 1, false) { + match inference_result { SingleInferenceResult::Modified => { - for other_expr in &link_data.used_at { - let other_expr_idx = ctx.heap[*other_expr].get_unique_id_in_definition(); - self.expr_queued.push_back(other_expr_idx); + for used_at in linked_var_data.used_at.iter().copied() { + self.node_queued.push_back(used_at); } }, SingleInferenceResult::Unmodified => {}, SingleInferenceResult::Incompatible => { - let var_data = self.var_types.get(&var_id).unwrap(); - let link_data = self.var_types.get(&linked_id).unwrap(); - let var_decl = &ctx.heap[var_id]; - let link_decl = &ctx.heap[linked_id]; + let var_data_this = &self.var_data[var_data_index]; + let var_decl_this = &ctx.heap[var_data_this.var_id]; + let var_data_linked = &self.var_data[linked_var_data_index]; + let var_decl_linked = &ctx.heap[var_data_linked.var_id]; return Err(ParseError::new_error_at_span( - &ctx.module().source, var_decl.identifier.span, format!( - "Conflicting types for this variable, assigned the type '{}'", - var_data.var_type.display_name(&ctx.heap) + &ctx.module().source, var_decl_this.identifier.span, format!( + "conflicting types for this channel, this port has type '{}'", + var_data_this.var_type.display_name(&ctx.heap) ) ).with_info_at_span( - &ctx.module().source, link_decl.identifier.span, format!( - "Because it is incompatible with this variable, assigned the type '{}'", - link_data.var_type.display_name(&ctx.heap) + &ctx.module().source, var_decl_linked.identifier.span, format!( + "while this port has type '{}'", + var_data_linked.var_type.display_name(&ctx.heap) ) )); } } } } - if progress_expr { 
self.queue_expr_parent(ctx, upcast_id); } - debug_log!(" * After:"); - debug_log!(" - Var type [{}]: {}", progress_var, self.var_types.get(&var_id).unwrap().var_type.display_name(&ctx.heap)); - debug_log!(" - Expr type [{}]: {}", progress_expr, self.debug_get_display_name(ctx, upcast_id)); + if progress_expr { self.queue_node_parent(node_index); } + + return Ok(()); + } + fn progress_template(&mut self, ctx: &Ctx, node_index: InferNodeIndex, application: InferenceRuleTemplateApplication, template: &[InferenceTypePart]) -> Result { + use InferenceRuleTemplateApplication as TA; - Ok(()) + match application { + TA::None => Ok(false), + TA::Template => self.apply_template_constraint(ctx, node_index, template), + TA::Forced => self.apply_forced_constraint(ctx, node_index, template), + } } - fn queue_expr_parent(&mut self, ctx: &Ctx, expr_id: ExpressionId) { - if let ExpressionParent::Expression(parent_expr_id, _) = &ctx.heap[expr_id].parent() { - let expr_idx = ctx.heap[*parent_expr_id].get_unique_id_in_definition(); - self.expr_queued.push_back(expr_idx); + fn queue_node_parent(&mut self, node_index: InferNodeIndex) { + let node = &self.infer_nodes[node_index]; + if let Some(parent_node_index) = node.parent_index { + self.node_queued.push_back(parent_node_index); } } - fn queue_expr(&mut self, ctx: &Ctx, expr_id: ExpressionId) { - let expr_idx = ctx.heap[expr_id].get_unique_id_in_definition(); - self.expr_queued.push_back(expr_idx); + #[inline] + fn queue_node(&mut self, node_index: InferNodeIndex) { + self.node_queued.push_back(node_index); } + /// Returns whether the type is certainly a string (true, false), certainly + /// not a string (false, true), or still unknown (false, false). + fn type_is_certainly_or_certainly_not_string(&self, node_index: InferNodeIndex) -> (bool, bool) { + let expr_type = &self.infer_nodes[node_index].expr_type; + let mut part_index = 0; + while part_index < expr_type.parts.len() { + let part = &expr_type.parts[part_index]; + + if part.is_marker() { + part_index += 1; + continue; + } + if !part.is_concrete() { break; } - // first returned is certainly string, second is certainly not - fn type_is_certainly_or_certainly_not_string(&self, ctx: &Ctx, expr_id: ExpressionId) -> (bool, bool) { - let expr_idx = ctx.heap[expr_id].get_unique_id_in_definition(); - let expr_type = &self.expr_types[expr_idx as usize].expr_type; - if expr_type.is_done { - if expr_type.parts[0] == InferenceTypePart::String { + if *part == InferenceTypePart::String { + // First part is a string return (true, false); } else { return (false, true); } } + // If here then first non-marker type is not concrete + if part_index == expr_type.parts.len() { + // nothing known at all + return (false, false); + } + + // Special case: array-like where its argument is not a character + if part_index + 1 < expr_type.parts.len() { + if expr_type.parts[part_index] == InferenceTypePart::ArrayLike && expr_type.parts[part_index + 1] != InferenceTypePart::Character { + return (false, true); + } + } + + (false, false) } @@ -2967,45 +3026,30 @@ impl PassTyping { /// expression type as well. Hence the template may be fully specified (e.g. /// a bool) or contain "inference" variables (e.g. 
an array of T) fn apply_template_constraint( - &mut self, ctx: &Ctx, expr_id: ExpressionId, template: &[InferenceTypePart] + &mut self, ctx: &Ctx, node_index: InferNodeIndex, template: &[InferenceTypePart] ) -> Result { - let expr_idx = ctx.heap[expr_id].get_unique_id_in_definition(); // TODO: @Temp - let expr_type = &mut self.expr_types[expr_idx as usize].expr_type; + let expr_type = &mut self.infer_nodes[node_index].expr_type; match InferenceType::infer_subtree_for_single_type(expr_type, 0, template, 0, false) { SingleInferenceResult::Modified => Ok(true), SingleInferenceResult::Unmodified => Ok(false), SingleInferenceResult::Incompatible => Err( - self.construct_template_type_error(ctx, expr_id, template) + self.construct_template_type_error(ctx, node_index, template) ) } } - fn apply_template_constraint_to_types( - to_infer: *mut InferenceType, to_infer_start_idx: usize, - template: &[InferenceTypePart], template_start_idx: usize - ) -> Result { - match InferenceType::infer_subtree_for_single_type( - unsafe{ &mut *to_infer }, to_infer_start_idx, - template, template_start_idx, false - ) { - SingleInferenceResult::Modified => Ok(true), - SingleInferenceResult::Unmodified => Ok(false), - SingleInferenceResult::Incompatible => Err(()), - } - } - /// Applies a forced constraint: the supplied expression's type MUST be /// inferred from the template, the other way around is considered invalid. fn apply_forced_constraint( - &mut self, ctx: &Ctx, expr_id: ExpressionId, template: &[InferenceTypePart] + &mut self, ctx: &Ctx, node_index: InferNodeIndex, template: &[InferenceTypePart] ) -> Result { - let expr_idx = ctx.heap[expr_id].get_unique_id_in_definition(); - let expr_type = &mut self.expr_types[expr_idx as usize].expr_type; + let expr_type = &mut self.infer_nodes[node_index].expr_type; + match InferenceType::infer_subtree_for_single_type(expr_type, 0, template, 0, true) { SingleInferenceResult::Modified => Ok(true), SingleInferenceResult::Unmodified => Ok(false), SingleInferenceResult::Incompatible => Err( - self.construct_template_type_error(ctx, expr_id, template) + self.construct_template_type_error(ctx, node_index, template) ) } } @@ -3015,184 +3059,228 @@ impl PassTyping { /// is successful then the composition of all types are made equal. /// The "parent" `expr_id` is provided to construct errors. 
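+    ///
+    /// A small illustrative note (a sketch of usage, not an exhaustive
+    /// contract): the start indices select which subtree of each inferred
+    /// type participates, so passing 0 for both unifies the full types of
+    /// the two nodes. This is, for example, how `apply_equal_n_constraint`
+    /// later in this file chains consecutive expressions pairwise.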
     fn apply_equal2_constraint(
-        &mut self, ctx: &Ctx, expr_id: ExpressionId,
-        arg1_id: ExpressionId, arg1_start_idx: usize,
-        arg2_id: ExpressionId, arg2_start_idx: usize
+        &mut self, ctx: &Ctx, node_index: InferNodeIndex,
+        arg1_index: InferNodeIndex, arg1_start_idx: usize,
+        arg2_index: InferNodeIndex, arg2_start_idx: usize
     ) -> Result<(bool, bool), ParseError> {
-        let arg1_expr_idx = ctx.heap[arg1_id].get_unique_id_in_definition(); // TODO: @Temp
-        let arg2_expr_idx = ctx.heap[arg2_id].get_unique_id_in_definition();
-        let arg1_type: *mut _ = &mut self.expr_types[arg1_expr_idx as usize].expr_type;
-        let arg2_type: *mut _ = &mut self.expr_types[arg2_expr_idx as usize].expr_type;
+        let arg1_type: *mut _ = &mut self.infer_nodes[arg1_index].expr_type;
+        let arg2_type: *mut _ = &mut self.infer_nodes[arg2_index].expr_type;

         let infer_res = unsafe{ InferenceType::infer_subtrees_for_both_types(
             arg1_type, arg1_start_idx, arg2_type, arg2_start_idx
         ) };
         if infer_res == DualInferenceResult::Incompatible {
-            return Err(self.construct_arg_type_error(ctx, expr_id, arg1_id, arg2_id));
+            return Err(self.construct_arg_type_error(ctx, node_index, arg1_index, arg2_index));
         }

         Ok((infer_res.modified_lhs(), infer_res.modified_rhs()))
     }

-    /// Applies an equal2 constraint between a signature type (e.g. a function
-    /// argument or struct field) and an expression whose type should match that
-    /// expression. If we make progress on the signature, then we try to see if
-    /// any of the embedded polymorphic types can be progressed.
+    /// Applies an equal2 constraint between a member of the `PolyData` struct,
+    /// and another inferred type. If any progress is made in the `PolyData`
+    /// struct then the affected polymorphic variables are updated as well.
+    ///
+    /// Because a lot of types/expressions are involved in polymorphic type
+    /// inference, some explanation: "outer_node" refers to the main expression
+    /// that is the root cause of type inference (e.g. a struct literal
+    /// expression, or a tuple member select expression). Associated with that
+    /// outer node is `PolyData`, so that is what the "poly_data" variables
+    /// are referring to. We are applying equality between a "poly_data" type
+    /// and an associated expression (not necessarily the "outer_node", e.g.
+    /// the expression that constructs the value of a struct field). Hence the
+    /// "associated" variables.
     ///
-    /// `outer_expr_id` is the main expression we're progressing (e.g. a
-    /// function call), while `expr_id` is the embedded expression we're
-    /// matching against the signature. `expression_type` and
-    /// `expression_start_idx` belong to `expr_id`.
-    fn apply_equal2_signature_constraint(
-        ctx: &Ctx, outer_expr_id: ExpressionId, expr_id: Option,
-        polymorph_data: &mut ExtraData, polymorph_progress: &mut HashSet,
-        signature_type: *mut InferenceType, signature_start_idx: usize,
-        expression_type: *mut InferenceType, expression_start_idx: usize
+    /// Finally, when an error occurs we'll first show the outer node's
+    /// location. As info, the `error_location_expr_id` span is shown,
+    /// indicating that the "`error_type_name` type has been resolved to
+    /// `outer_node_type`, but this expression has been resolved to
+    /// `associated_node_type`".
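+    ///
+    /// A rough, hypothetical illustration (expression names invented here):
+    /// for a call `make_pair(a, b)` to a polymorphic function, the call
+    /// expression is the outer node. Calling this with
+    /// `PolyDataTypeIndex::Associated(0)` unifies the first parameter's
+    /// signature type with the inferred type of `a`, while
+    /// `PolyDataTypeIndex::Returned` unifies the signature's return type with
+    /// the call expression itself. Whenever a polymorphic variable marker in
+    /// the signature becomes more specific, its index is pushed onto
+    /// `poly_progress_section`.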
+ fn apply_polydata_equal2_constraint( + &mut self, ctx: &Ctx, + outer_node_index: InferNodeIndex, error_location_expr_id: ExpressionId, error_type_name: &str, + poly_data_type_index: PolyDataTypeIndex, poly_data_start_index: usize, + associated_node_index: InferNodeIndex, associated_node_start_index: usize, + poly_progress_section: &mut ScopedSection, ) -> Result<(bool, bool), ParseError> { - // Safety: all pointers distinct - - // Infer the signature and expression type - let infer_res = unsafe { + let poly_data_index = self.infer_nodes[outer_node_index].poly_data_index; + let poly_data = &mut self.poly_data[poly_data_index as usize]; + let poly_data_type = poly_data.expr_types.get_type_mut(poly_data_type_index); + let associated_type: *mut _ = &mut self.infer_nodes[associated_node_index].expr_type; + + let inference_result = unsafe{ + // Safety: pointers originate from different vectors, so cannot + // alias. + let poly_data_type: *mut _ = poly_data_type; InferenceType::infer_subtrees_for_both_types( - signature_type, signature_start_idx, - expression_type, expression_start_idx - ) + poly_data_type, poly_data_start_index, + associated_type, associated_node_start_index + ) }; - if infer_res == DualInferenceResult::Incompatible { - // TODO: Check if I still need to use this - let outer_span = ctx.heap[outer_expr_id].full_span(); - let (span_name, span) = match expr_id { - Some(expr_id) => ("argument's", ctx.heap[expr_id].full_span()), - None => ("type's", outer_span) - }; - let (signature_display_type, expression_display_type) = unsafe { ( - (&*signature_type).display_name(&ctx.heap), - (&*expression_type).display_name(&ctx.heap) - ) }; + let modified_poly_data = inference_result.modified_lhs(); + let modified_associated = inference_result.modified_rhs(); + if inference_result == DualInferenceResult::Incompatible { + let outer_node_expr_id = self.infer_nodes[outer_node_index].expr_id; + let outer_node_span = ctx.heap[outer_node_expr_id].full_span(); + let detailed_span = ctx.heap[error_location_expr_id].full_span(); + let outer_node_type = poly_data_type.display_name(&ctx.heap); + let associated_type = self.infer_nodes[associated_node_index].expr_type.display_name(&ctx.heap); + + let source = &ctx.module().source; return Err(ParseError::new_error_str_at_span( - &ctx.module().source, outer_span, - "failed to fully resolve the types of this expression" - ).with_info_at_span( - &ctx.module().source, span, format!( - "because the {} signature has been resolved to '{}', but the expression has been resolved to '{}'", - span_name, signature_display_type, expression_display_type + source, outer_node_span, "failed to resolve the types of this expression" + ).with_info_str_at_span( + source, detailed_span, &format!( + "because the {} type has been resolved to '{}', but this expression has been resolved to '{}'", + error_type_name, outer_node_type, associated_type ) )); } - // Try to see if we can progress any of the polymorphic variables - let progress_sig = infer_res.modified_lhs(); - let progress_expr = infer_res.modified_rhs(); + if modified_poly_data { + debug_assert!(poly_data_type.has_marker); - if progress_sig { - let signature_type = unsafe{&mut *signature_type}; - debug_assert!( - signature_type.has_marker, - "made progress on signature type, but it doesn't have a marker" - ); - for (poly_idx, poly_section) in signature_type.marker_iter() { - let polymorph_type = &mut polymorph_data.poly_vars[poly_idx as usize]; - match Self::apply_template_constraint_to_types( - polymorph_type, 0, 
poly_section, 0 - ) { - Ok(true) => { polymorph_progress.insert(poly_idx); }, - Ok(false) => {}, - Err(()) => { return Err(Self::construct_poly_arg_error(ctx, polymorph_data, outer_expr_id))} + // Go through markers for polymorphic variables and use the + // (hopefully) more specific types to update their representation + // in the PolyData struct + for (poly_var_index, poly_var_section) in poly_data_type.marker_iter() { + let poly_var_type = &mut poly_data.poly_vars[poly_var_index as usize]; + match InferenceType::infer_subtree_for_single_type(poly_var_type, 0, poly_var_section, 0, false) { + SingleInferenceResult::Modified => { + poly_progress_section.push_unique(poly_var_index); + }, + SingleInferenceResult::Unmodified => { + // nothing to do + }, + SingleInferenceResult::Incompatible => { + return Err(Self::construct_poly_arg_error( + ctx, &self.poly_data[poly_data_index as usize], + self.infer_nodes[outer_node_index].expr_id + )); + } } } } - Ok((progress_sig, progress_expr)) + + return Ok((modified_poly_data, modified_associated)); } - /// Applies equal2 constraints on the signature type for each of the - /// polymorphic variables. If the signature type is progressed then we - /// progress the expression type as well. + /// After calling `apply_polydata_equal2_constraint` on several expressions + /// that are associated with some kind of polymorphic expression, several of + /// the polymorphic variables might have been inferred to more specific + /// types than before. /// - /// This function assumes that the polymorphic variables have already been - /// progressed as far as possible by calling - /// `apply_equal2_signature_constraint`. As such, we expect to not encounter - /// any errors. + /// At this point one should call this function to apply the progress in + /// these polymorphic variables back onto the types that are functions of + /// these polymorphic variables. /// - /// This function returns true if the expression's type has been progressed - fn apply_equal2_polyvar_constraint( - polymorph_data: &ExtraData, _polymorph_progress: &HashSet, - signature_type: *mut InferenceType, expr_type: *mut InferenceType + /// An example: a struct literal with a polymorphic variable `T` may have + /// two fields `foo` and `bar` each with different types that are a function + /// of the polymorhic variable `T`. If the expressions constructing the + /// value for the field `foo` causes the type `T` to progress, then we can + /// also progress the type of the expression that constructs `bar`. + /// + /// And so we have `outer_node_index` + `poly_data_type_index` pointing to + /// the appropriate type in the `PolyData` struct. Which will be updated + /// first using the polymorphic variables. If we happen to have updated that + /// type, then we should also progress the associated expression, hence the + /// `associated_node_index`. 
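+    ///
+    /// Continuing that (hypothetical) struct example: once the expression for
+    /// `foo` has made `T` more specific, calling this function with the
+    /// `PolyDataTypeIndex` of `bar` first rewrites `bar`'s signature type
+    /// using the updated polymorphic variables; only if that signature
+    /// actually changed is the change forced onto the expression constructing
+    /// `bar`, and `true` is returned so the caller knows to re-queue it.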
+ fn apply_polydata_polyvar_constraint( + &mut self, _ctx: &Ctx, + outer_node_index: InferNodeIndex, poly_data_type_index: PolyDataTypeIndex, + associated_node_index: InferNodeIndex, poly_progress_section: &ScopedSection ) -> bool { - // Safety: all pointers should be distinct - // polymorph_data containers may not be modified - let signature_type = unsafe{&mut *signature_type}; - let expr_type = unsafe{&mut *expr_type}; - - // Iterate through markers in signature type to try and make progress - // on the polymorphic variable - let mut seek_idx = 0; - let mut modified_sig = false; - - while let Some((poly_idx, start_idx)) = signature_type.find_marker(seek_idx) { - let end_idx = InferenceType::find_subtree_end_idx(&signature_type.parts, start_idx); - // if polymorph_progress.contains(&poly_idx) { - // Need to match subtrees - let polymorph_type = &polymorph_data.poly_vars[poly_idx as usize]; - let modified_at_marker = Self::apply_template_constraint_to_types( - signature_type, start_idx, - &polymorph_type.parts, 0 - ).expect("no failure when applying polyvar constraints"); - - modified_sig = modified_sig || modified_at_marker; - // } - - seek_idx = end_idx; + let poly_data_index = self.infer_nodes[outer_node_index].poly_data_index; + let poly_data = &mut self.poly_data[poly_data_index as usize]; + + // Early exit, most common case (literals or functions calls which are + // actually not polymorphic) + if !poly_data.first_rule_application && poly_progress_section.len() == 0 { + return false; } - // If we made any progress on the signature's type, then we also need to - // apply it to the expression that is supposed to match the signature. - if modified_sig { + // safety: we're borrowing from two distinct fields, so should be fine + let poly_data_type = poly_data.expr_types.get_type_mut(poly_data_type_index); + let mut last_start_index = 0; + let mut modified_poly_type = false; + + while let Some((poly_var_index, poly_var_start_index)) = poly_data_type.find_marker(last_start_index) { + let poly_var_end_index = InferenceType::find_subtree_end_idx(&poly_data_type.parts, poly_var_start_index); + + if poly_data.first_rule_application || poly_progress_section.contains(&poly_var_index) { + // We have updated this polymorphic variable, so try updating it + // in the PolyData type + let modified_in_poly_data = match InferenceType::infer_subtree_for_single_type( + poly_data_type, poly_var_start_index, &poly_data.poly_vars[poly_var_index as usize].parts, 0, false + ) { + SingleInferenceResult::Modified => true, + SingleInferenceResult::Unmodified => false, + SingleInferenceResult::Incompatible => { + // practically impossible: before calling this function we gather all the + // data on the polymorphic variables from the associated expressions. So if + // the polymorphic variables in those expressions were not mutually + // compatible, we must have encountered that error already. 
+ unreachable!() + }, + }; + + modified_poly_type = modified_poly_type || modified_in_poly_data; + } + + last_start_index = poly_var_end_index; + } + + if modified_poly_type { + let associated_type = &mut self.infer_nodes[associated_node_index].expr_type; match InferenceType::infer_subtree_for_single_type( - expr_type, 0, &signature_type.parts, 0, true + associated_type, 0, &poly_data_type.parts, 0, true ) { - SingleInferenceResult::Modified => true, - SingleInferenceResult::Unmodified => false, - SingleInferenceResult::Incompatible => - unreachable!("encountered failure while reapplying modified signature to expression after polyvar inference") + SingleInferenceResult::Modified => return true, + SingleInferenceResult::Unmodified => return false, + SingleInferenceResult::Incompatible => unreachable!(), // same as above } } else { - false + // Did not update associated type + return false; } } + /// Should be called after completing one full round of applying polydata + /// constraints. + fn finish_polydata_constraint(&mut self, outer_node_index: InferNodeIndex) { + let poly_data_index = self.infer_nodes[outer_node_index].poly_data_index; + let poly_data = &mut self.poly_data[poly_data_index as usize]; + poly_data.first_rule_application = false; + } + /// Applies a type constraint that expects all three provided types to be /// equal. In case we can make progress in inferring the types then we /// attempt to do so. If the call is successful then the composition of all /// types is made equal. fn apply_equal3_constraint( - &mut self, ctx: &Ctx, expr_id: ExpressionId, - arg1_id: ExpressionId, arg2_id: ExpressionId, + &mut self, ctx: &Ctx, node_index: InferNodeIndex, + arg1_index: InferNodeIndex, arg2_index: InferNodeIndex, start_idx: usize ) -> Result<(bool, bool, bool), ParseError> { - // Safety: all points are unique + // Safety: all indices are unique // containers may not be modified - let expr_expr_idx = ctx.heap[expr_id].get_unique_id_in_definition(); // TODO: @Temp - let arg1_expr_idx = ctx.heap[arg1_id].get_unique_id_in_definition(); - let arg2_expr_idx = ctx.heap[arg2_id].get_unique_id_in_definition(); - - let expr_type: *mut _ = &mut self.expr_types[expr_expr_idx as usize].expr_type; - let arg1_type: *mut _ = &mut self.expr_types[arg1_expr_idx as usize].expr_type; - let arg2_type: *mut _ = &mut self.expr_types[arg2_expr_idx as usize].expr_type; + let expr_type: *mut _ = &mut self.infer_nodes[node_index].expr_type; + let arg1_type: *mut _ = &mut self.infer_nodes[arg1_index].expr_type; + let arg2_type: *mut _ = &mut self.infer_nodes[arg2_index].expr_type; let expr_res = unsafe{ InferenceType::infer_subtrees_for_both_types(expr_type, start_idx, arg1_type, start_idx) }; if expr_res == DualInferenceResult::Incompatible { - return Err(self.construct_expr_type_error(ctx, expr_id, arg1_id)); + return Err(self.construct_expr_type_error(ctx, node_index, arg1_index)); } let args_res = unsafe{ InferenceType::infer_subtrees_for_both_types(arg1_type, start_idx, arg2_type, start_idx) }; if args_res == DualInferenceResult::Incompatible { - return Err(self.construct_arg_type_error(ctx, expr_id, arg1_id, arg2_id)); + return Err(self.construct_arg_type_error(ctx, node_index, arg1_index, arg2_index)); } // If all types are compatible, but the second call caused the arg1_type @@ -3216,15 +3304,15 @@ impl PassTyping { /// Applies equal constraint to N consecutive expressions. 
The returned /// `progress` vec will contain which expressions were progressed and will - /// have length N - // If you ever + /// have length N. fn apply_equal_n_constraint( - &mut self, ctx: &Ctx, expr_id: ExpressionId, - args: &ScopedSection, progress: &mut ScopedSection + &mut self, ctx: &Ctx, outer_node_index: InferNodeIndex, + arguments: &ScopedSection, progress: &mut ScopedSection ) -> Result<(), ParseError> { - // Early exit + // Depending on the argument perform an early exit. This simplifies + // later logic debug_assert_eq!(progress.len(), 0); - match args.len() { + match arguments.len() { 0 => { // nothing to progress return Ok(()) @@ -3241,65 +3329,62 @@ impl PassTyping { } } - // Do pairwise inference, keep track of the last entry we made progress - // on. Once done we need to update everything to the most-inferred type. - let mut arg_iter = args.iter_copied(); - let mut last_arg_id = arg_iter.next().unwrap(); - let mut last_lhs_progressed = 0; - let mut lhs_arg_idx = 0; - - while let Some(next_arg_id) = arg_iter.next() { - let last_expr_idx = ctx.heap[last_arg_id].get_unique_id_in_definition(); // TODO: @Temp - let next_expr_idx = ctx.heap[next_arg_id].get_unique_id_in_definition(); - let last_type: *mut _ = &mut self.expr_types[last_expr_idx as usize].expr_type; - let next_type: *mut _ = &mut self.expr_types[next_expr_idx as usize].expr_type; - - let res = unsafe { - InferenceType::infer_subtrees_for_both_types(last_type, 0, next_type, 0) - }; + // We'll start doing pairwise inference for all of the inference nodes + // (node[0] with node[1], then node[1] with node[2], then node[2] ..., + // etc.), so when we're at the end we have `node[N-1]` as the most + // progressed type. + let mut last_index_requiring_inference = 0; - if res == DualInferenceResult::Incompatible { - return Err(self.construct_arg_type_error(ctx, expr_id, last_arg_id, next_arg_id)); - } + for prev_argument_index in 0..arguments.len() - 1 { + let next_argument_index = prev_argument_index + 1; - if res.modified_lhs() { - // We re-inferred something on the left hand side, so everything - // up until now should be re-inferred. - progress[lhs_arg_idx] = true; - last_lhs_progressed = lhs_arg_idx; - } - progress[lhs_arg_idx + 1] = res.modified_rhs(); + let prev_node_index = arguments[prev_argument_index]; + let next_node_index = arguments[next_argument_index]; + let (prev_progress, next_progress) = self.apply_equal2_constraint( + ctx, outer_node_index, prev_node_index, 0, next_node_index, 0 + )?; - last_arg_id = next_arg_id; - lhs_arg_idx += 1; + if prev_progress { + // Previous node is progress, so every type in front of it needs + // to be reinferred. + progress[prev_argument_index] = true; + last_index_requiring_inference = prev_argument_index; + } + progress[next_argument_index] = next_progress; } - // Re-infer everything. Note that we do not need to re-infer the type - // exactly at `last_lhs_progressed`, but only everything up to it. 
- let last_arg_expr_idx = ctx.heap[last_arg_id].get_unique_id_in_definition(); - let last_type: *mut _ = &mut self.expr_types[last_arg_expr_idx as usize].expr_type; - for arg_idx in 0..last_lhs_progressed { - let other_arg_expr_idx = ctx.heap[args[arg_idx]].get_unique_id_in_definition(); - let arg_type: *mut _ = &mut self.expr_types[other_arg_expr_idx as usize].expr_type; - unsafe{ - (*arg_type).replace_subtree(0, &(*last_type).parts); + // Apply inference using the most progressed type (the last one) to the + // ones that did not obtain this information during the inference + // process. + let last_argument_node_index = arguments[arguments.len() - 1]; + let last_argument_type: *mut _ = &mut self.infer_nodes[last_argument_node_index].expr_type; + + for argument_index in 0..last_index_requiring_inference { + // We can cheat, we know the LHS is less specific than the right + // hand side, so: + let argument_node_index = arguments[argument_index]; + let argument_type = &mut self.infer_nodes[argument_node_index].expr_type; + unsafe { + // safety: we're dealing with different vectors, so cannot alias + argument_type.replace_subtree(0, &(*last_argument_type).parts); } - progress[arg_idx] = true; + progress[argument_index] = true; } return Ok(()); } /// Determines the `InferenceType` for the expression based on the - /// expression parent. Note that if the parent is another expression, we do - /// not take special action, instead we let parent expressions fix the type - /// of subexpressions before they have a chance to call this function. - fn insert_initial_expr_inference_type( + /// expression parent (this is not done if the parent is a regular 'ol + /// expression). Expects `parent_index` to be set to the parent of the + /// inference node that is created here. + fn insert_initial_inference_node( &mut self, ctx: &mut Ctx, expr_id: ExpressionId - ) -> Result<(), ParseError> { + ) -> Result { use ExpressionParent as EP; use InferenceTypePart as ITP; + // Set the initial inference type based on the expression parent. let expr = &ctx.heap[expr_id]; let inference_type = match expr.parent() { EP::None => @@ -3326,64 +3411,36 @@ impl PassTyping { EP::If(_) | EP::While(_) => // Must be a boolean InferenceType::new(false, true, vec![ITP::Bool]), - EP::Return(_) => + EP::Return(_) => { // Must match the return type of the function - if let DefinitionType::Function(func_id) = self.definition_type { - debug_assert_eq!(ctx.heap[func_id].return_types.len(), 1); - let returned = &ctx.heap[func_id].return_types[0]; - self.determine_inference_type_from_parser_type_elements(&returned.elements, true) - } else { - // Cannot happen: definition always set upon body traversal - // and "return" calls in components are illegal. 
- unreachable!(); - }, + debug_assert_eq!(self.procedure_kind, ProcedureKind::Function); + let returned = &ctx.heap[self.procedure_id].return_type.as_ref().unwrap(); + self.determine_inference_type_from_parser_type_elements(&returned.elements, true) + }, EP::New(_) => // Must be a component call, which we assign a "Void" return // type InferenceType::new(false, true, vec![ITP::Void]), }; - let infer_expr = &mut self.expr_types[expr.get_unique_id_in_definition() as usize]; - let needs_extra_data = match expr { - Expression::Call(_) => true, - Expression::Literal(expr) => match expr.value { - Literal::Enum(_) | Literal::Union(_) | Literal::Struct(_) => true, - _ => false, - }, - Expression::Select(expr) => match expr.kind { - SelectKind::StructField(_) => true, - SelectKind::TupleMember(_) => false, - }, - _ => false, - }; - - if infer_expr.expr_id.is_invalid() { - // Nothing is set yet - infer_expr.expr_type = inference_type; - infer_expr.expr_id = expr_id; - if needs_extra_data { - let extra_idx = self.extra_data.len() as i32; - self.extra_data.push(ExtraData::default()); - infer_expr.extra_data_idx = extra_idx; - } - } else { - // We already have an entry - debug_assert!(false, "does this ever happen?"); - if let SingleInferenceResult::Incompatible = InferenceType::infer_subtree_for_single_type( - &mut infer_expr.expr_type, 0, &inference_type.parts, 0, false - ) { - return Err(self.construct_expr_type_error(ctx, expr_id, expr_id)); - } - - debug_assert!((infer_expr.extra_data_idx != -1) == needs_extra_data); - } - - Ok(()) + let infer_index = self.infer_nodes.len() as InferNodeIndex; + self.infer_nodes.push(InferenceNode { + expr_type: inference_type, + expr_id, + inference_rule: InferenceRule::Noop, + parent_index: self.parent_index, + field_index: -1, + poly_data_index: -1, + info_type_id: TypeId::new_invalid(), + info_variant: ExpressionInfoVariant::Generic, + }); + + return Ok(infer_index); } fn insert_initial_call_polymorph_data( &mut self, ctx: &mut Ctx, call_id: CallExpressionId - ) { + ) -> PolyDataIndex { // Note: the polymorph variables may be partially specified and may // contain references to the wrapping definition's (i.e. the proctype // we are currently visiting) polymorphic arguments. @@ -3394,8 +3451,6 @@ impl PassTyping { // map them back and forth to the polymorphic arguments of the function // we are calling. 
let call = &ctx.heap[call_id]; - let extra_data_idx = self.expr_types[call.unique_id_in_definition as usize].extra_data_idx; // TODO: @Temp - debug_assert!(extra_data_idx != -1, "insert initial call polymorph data, no preallocated ExtraData"); // Handle the polymorphic arguments (if there are any) let num_poly_args = call.parser_type.elements[0].variant.num_embedded(); @@ -3405,55 +3460,46 @@ impl PassTyping { } // Handle the arguments and return types - let definition = &ctx.heap[call.definition]; - let (parameters, returned) = match definition { - Definition::Component(definition) => { - debug_assert_eq!(poly_args.len(), definition.poly_vars.len()); - (&definition.parameters, None) - }, - Definition::Function(definition) => { - debug_assert_eq!(poly_args.len(), definition.poly_vars.len()); - (&definition.parameters, Some(&definition.return_types)) - }, - Definition::Struct(_) | Definition::Enum(_) | Definition::Union(_) => { - unreachable!("insert_initial_call_polymorph data for non-procedure type"); - }, - }; + let definition = &ctx.heap[call.procedure]; + debug_assert_eq!(poly_args.len(), definition.poly_vars.len()); - let mut parameter_types = Vec::with_capacity(parameters.len()); - for parameter_id in parameters.clone().into_iter() { // TODO: @Performance @Now + let mut parameter_types = Vec::with_capacity(definition.parameters.len()); + let parameter_section = self.var_buffer.start_section_initialized(&definition.parameters); + for parameter_id in parameter_section.iter_copied() { let param = &ctx.heap[parameter_id]; parameter_types.push(self.determine_inference_type_from_parser_type_elements(¶m.parser_type.elements, false)); } + parameter_section.forget(); - let return_type = match returned { + let return_type = match &definition.return_type { None => { // Component, so returns a "Void" + debug_assert_ne!(definition.kind, ProcedureKind::Function); InferenceType::new(false, true, vec![InferenceTypePart::Void]) }, Some(returned) => { - debug_assert_eq!(returned.len(), 1); // TODO: @ReturnTypes - let returned = &returned[0]; + debug_assert_eq!(definition.kind, ProcedureKind::Function); self.determine_inference_type_from_parser_type_elements(&returned.elements, false) } }; - self.extra_data[extra_data_idx as usize] = ExtraData{ - expr_id: call_id.upcast(), - definition_id: call.definition, + let extra_data_idx = self.poly_data.len() as PolyDataIndex; + self.poly_data.push(PolyData { + first_rule_application: true, + definition_id: call.procedure.upcast(), poly_vars: poly_args, - embedded: parameter_types, - returned: return_type - }; + expr_types: PolyDataTypes { + associated: parameter_types, + returned: return_type + } + }); + return extra_data_idx } fn insert_initial_struct_polymorph_data( &mut self, ctx: &mut Ctx, lit_id: LiteralExpressionId, - ) { + ) -> PolyDataIndex { use InferenceTypePart as ITP; - let literal = &ctx.heap[lit_id]; - let extra_data_idx = self.expr_types[literal.unique_id_in_definition as usize].extra_data_idx; // TODO: @Temp - debug_assert!(extra_data_idx != -1, "initial struct polymorph data, but no preallocated ExtraData"); let literal = ctx.heap[lit_id].value.as_struct(); // Handle polymorphic arguments @@ -3501,13 +3547,18 @@ impl PassTyping { debug_assert_eq!(parts.len(), parts_reserved); let return_type = InferenceType::new(!poly_args.is_empty(), return_type_done, parts); - self.extra_data[extra_data_idx as usize] = ExtraData{ - expr_id: lit_id.upcast(), + let extra_data_index = self.poly_data.len() as PolyDataIndex; + self.poly_data.push(PolyData { + 
first_rule_application: true, definition_id: literal.definition, poly_vars: poly_args, - embedded: embedded_types, - returned: return_type, - }; + expr_types: PolyDataTypes { + associated: embedded_types, + returned: return_type, + }, + }); + + return extra_data_index } /// Inserts the extra polymorphic data struct for enum expressions. These @@ -3515,11 +3566,8 @@ impl PassTyping { /// the use of the enum. fn insert_initial_enum_polymorph_data( &mut self, ctx: &Ctx, lit_id: LiteralExpressionId - ) { + ) -> PolyDataIndex { use InferenceTypePart as ITP; - let literal = &ctx.heap[lit_id]; - let extra_data_idx = self.expr_types[literal.unique_id_in_definition as usize].extra_data_idx; // TODO: @Temp - debug_assert!(extra_data_idx != -1, "initial enum polymorph data, but no preallocated ExtraData"); let literal = ctx.heap[lit_id].value.as_enum(); // Handle polymorphic arguments to the enum @@ -3548,24 +3596,26 @@ impl PassTyping { debug_assert_eq!(parts.len(), parts_reserved); let enum_type = InferenceType::new(!poly_args.is_empty(), enum_type_done, parts); - self.extra_data[extra_data_idx as usize] = ExtraData{ - expr_id: lit_id.upcast(), + let extra_data_index = self.poly_data.len() as PolyDataIndex; + self.poly_data.push(PolyData { + first_rule_application: true, definition_id: literal.definition, poly_vars: poly_args, - embedded: Vec::new(), - returned: enum_type, - }; + expr_types: PolyDataTypes { + associated: Vec::new(), + returned: enum_type, + }, + }); + + return extra_data_index; } /// Inserts the extra polymorphic data struct for unions. The polymorphic /// arguments may be partially determined from embedded values in the union. fn insert_initial_union_polymorph_data( &mut self, ctx: &Ctx, lit_id: LiteralExpressionId - ) { + ) -> PolyDataIndex { use InferenceTypePart as ITP; - let literal = &ctx.heap[lit_id]; - let extra_data_idx = self.expr_types[literal.unique_id_in_definition as usize].extra_data_idx; // TODO: @Temp - debug_assert!(extra_data_idx != -1, "initial union polymorph data, but no preallocated ExtraData"); let literal = ctx.heap[lit_id].value.as_union(); // Construct the polymorphic variables @@ -3609,30 +3659,30 @@ impl PassTyping { debug_assert_eq!(parts_reserved, parts.len()); let union_type = InferenceType::new(!poly_args.is_empty(), union_type_done, parts); - self.extra_data[extra_data_idx as usize] = ExtraData{ - expr_id: lit_id.upcast(), + let extra_data_index = self.poly_data.len() as isize; + self.poly_data.push(PolyData { + first_rule_application: true, definition_id: literal.definition, poly_vars: poly_args, - embedded, - returned: union_type - }; + expr_types: PolyDataTypes { + associated: embedded, + returned: union_type, + }, + }); + + return extra_data_index; } /// Inserts the extra polymorphic data struct. Assumes that the select /// expression's referenced (definition_id, field_idx) has been resolved. 
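+    ///
+    /// As a hypothetical example: for `struct Pair<T1, T2>{ first: T1, second: T2 }`
+    /// and a select expression `p.second`, the pushed `PolyData` entry holds
+    /// `T1` and `T2` as still-unknown polymorphic variables, the struct type
+    /// `Pair<T1, T2>` as its single associated type (to be unified with the
+    /// subject `p`), and the field type `T2` as its returned type (to be
+    /// unified with the select expression itself).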
fn insert_initial_select_polymorph_data( - &mut self, ctx: &Ctx, select_id: SelectExpressionId, struct_def_id: DefinitionId - ) { + &mut self, ctx: &Ctx, node_index: InferNodeIndex, struct_def_id: DefinitionId + ) -> PolyDataIndex { use InferenceTypePart as ITP; - // Retrieve relevant data - let expr = &ctx.heap[select_id]; - let expr_type = &self.expr_types[expr.unique_id_in_definition as usize]; - let field_idx = expr_type.field_or_monomorph_idx as usize; - let extra_data_idx = expr_type.extra_data_idx; // TODO: @Temp - debug_assert!(extra_data_idx != -1, "initial select polymorph data, but no preallocated ExtraData"); - let definition = ctx.heap[struct_def_id].as_struct(); + let node = &self.infer_nodes[node_index]; + let field_index = node.field_index as usize; // Generate initial polyvar types and struct type // TODO: @Performance: we can immediately set the polyvars of the subject's struct type @@ -3652,14 +3702,20 @@ impl PassTyping { debug_assert_eq!(struct_parts.len(), struct_parts_reserved); // Generate initial field type - let field_type = self.determine_inference_type_from_parser_type_elements(&definition.fields[field_idx].parser_type.elements, false); - self.extra_data[extra_data_idx as usize] = ExtraData{ - expr_id: select_id.upcast(), + let field_type = self.determine_inference_type_from_parser_type_elements(&definition.fields[field_index].parser_type.elements, false); + + let extra_data_index = self.poly_data.len() as PolyDataIndex; + self.poly_data.push(PolyData { + first_rule_application: true, definition_id: struct_def_id, poly_vars, - embedded: vec![InferenceType::new(num_poly_vars != 0, num_poly_vars == 0, struct_parts)], - returned: field_type - }; + expr_types: PolyDataTypes { + associated: vec![InferenceType::new(num_poly_vars != 0, num_poly_vars == 0, struct_parts)], + returned: field_type, + }, + }); + + return extra_data_index; } /// Determines the initial InferenceType from the provided ParserType. This @@ -3728,7 +3784,7 @@ impl PassTyping { if use_definitions_known_poly_args { // Refers to polymorphic argument on procedure we're currently processing. // This argument is already known. - debug_assert_eq!(*belongs_to_definition, self.definition_type.definition_id()); + debug_assert_eq!(*belongs_to_definition, self.procedure_id.upcast()); debug_assert!((poly_arg_idx as usize) < self.poly_vars.len()); Self::determine_inference_type_from_concrete_type( @@ -3782,6 +3838,7 @@ impl PassTyping { CTP::Slice => parser_type.push(ITP::Slice), CTP::Input => parser_type.push(ITP::Input), CTP::Output => parser_type.push(ITP::Output), + CTP::Pointer => unreachable!("pointer type during concrete to inference type conversion"), CTP::Tuple(num) => parser_type.push(ITP::Tuple(*num)), CTP::Instance(id, num) => parser_type.push(ITP::Instance(*id, *num)), CTP::Function(_, _) => unreachable!("function type during concrete to inference type conversion"), @@ -3796,41 +3853,39 @@ impl PassTyping { /// But the expression type was already set due to our parent (e.g. 
an /// "if statement" or a "logical not" always expecting a boolean) fn construct_expr_type_error( - &self, ctx: &Ctx, expr_id: ExpressionId, arg_id: ExpressionId + &self, ctx: &Ctx, expr_index: InferNodeIndex, arg_index: InferNodeIndex ) -> ParseError { // TODO: Expand and provide more meaningful information for humans - let expr = &ctx.heap[expr_id]; - let arg_expr = &ctx.heap[arg_id]; - let expr_idx = expr.get_unique_id_in_definition(); - let arg_expr_idx = arg_expr.get_unique_id_in_definition(); - let expr_type = &self.expr_types[expr_idx as usize].expr_type; - let arg_type = &self.expr_types[arg_expr_idx as usize].expr_type; + let expr_node = &self.infer_nodes[expr_index]; + let arg_node = &self.infer_nodes[arg_index]; + + let expr = &ctx.heap[expr_node.expr_id]; + let arg = &ctx.heap[arg_node.expr_id]; return ParseError::new_error_at_span( &ctx.module().source, expr.operation_span(), format!( "incompatible types: this expression expected a '{}'", - expr_type.display_name(&ctx.heap) + expr_node.expr_type.display_name(&ctx.heap) ) ).with_info_at_span( - &ctx.module().source, arg_expr.full_span(), format!( + &ctx.module().source, arg.full_span(), format!( "but this expression yields a '{}'", - arg_type.display_name(&ctx.heap) + arg_node.expr_type.display_name(&ctx.heap) ) ) } fn construct_arg_type_error( - &self, ctx: &Ctx, expr_id: ExpressionId, - arg1_id: ExpressionId, arg2_id: ExpressionId + &self, ctx: &Ctx, expr_index: InferNodeIndex, + arg1_index: InferNodeIndex, arg2_index: InferNodeIndex ) -> ParseError { - let expr = &ctx.heap[expr_id]; - let arg1 = &ctx.heap[arg1_id]; - let arg2 = &ctx.heap[arg2_id]; + let arg1_node = &self.infer_nodes[arg1_index]; + let arg2_node = &self.infer_nodes[arg2_index]; - let arg1_idx = arg1.get_unique_id_in_definition(); - let arg1_type = &self.expr_types[arg1_idx as usize].expr_type; - let arg2_idx = arg2.get_unique_id_in_definition(); - let arg2_type = &self.expr_types[arg2_idx as usize].expr_type; + let expr_id = self.infer_nodes[expr_index].expr_id; + let expr = &ctx.heap[expr_id]; + let arg1 = &ctx.heap[arg1_node.expr_id]; + let arg2 = &ctx.heap[arg2_node.expr_id]; return ParseError::new_error_str_at_span( &ctx.module().source, expr.operation_span(), @@ -3838,22 +3893,22 @@ impl PassTyping { ).with_info_at_span( &ctx.module().source, arg1.full_span(), format!( "Because this expression has type '{}'", - arg1_type.display_name(&ctx.heap) + arg1_node.expr_type.display_name(&ctx.heap) ) ).with_info_at_span( &ctx.module().source, arg2.full_span(), format!( "But this expression has type '{}'", - arg2_type.display_name(&ctx.heap) + arg2_node.expr_type.display_name(&ctx.heap) ) ) } fn construct_template_type_error( - &self, ctx: &Ctx, expr_id: ExpressionId, template: &[InferenceTypePart] + &self, ctx: &Ctx, node_index: InferNodeIndex, template: &[InferenceTypePart] ) -> ParseError { - let expr = &ctx.heap[expr_id]; - let expr_idx = expr.get_unique_id_in_definition(); - let expr_type = &self.expr_types[expr_idx as usize].expr_type; + let node = &self.infer_nodes[node_index]; + let expr = &ctx.heap[node.expr_id]; + let expr_type = &node.expr_type; return ParseError::new_error_at_span( &ctx.module().source, expr.full_span(), format!( @@ -3864,6 +3919,29 @@ impl PassTyping { ) } + fn construct_variable_type_error( + &self, ctx: &Ctx, node_index: InferNodeIndex, + ) -> ParseError { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_variable_expr(); + + let var_data = &self.var_data[rule.var_data_index]; + let var_decl = 
&ctx.heap[var_data.var_id]; + let var_expr = &ctx.heap[node.expr_id]; + + return ParseError::new_error_at_span( + &ctx.module().source, var_decl.identifier.span, format!( + "conflicting types for this variable, previously assigned the type '{}'", + var_data.var_type.display_name(&ctx.heap) + ) + ).with_info_at_span( + &ctx.module().source, var_expr.full_span(), format!( + "but inferred to have incompatible type '{}' here", + node.expr_type.display_name(&ctx.heap) + ) + ); + } + /// Constructs a human interpretable error in the case that type inference /// on a polymorphic variable to a function call or literal construction /// failed. This may only be caused by a pair of inference types (which may @@ -3875,7 +3953,7 @@ impl PassTyping { /// We assume that the expression is a function call or a struct literal, /// and that an actual error has occurred. fn construct_poly_arg_error( - ctx: &Ctx, poly_data: &ExtraData, expr_id: ExpressionId + ctx: &Ctx, poly_data: &PolyData, expr_id: ExpressionId ) -> ParseError { // Helper function to check for polymorph mismatch between two inference // types. @@ -3927,7 +4005,7 @@ impl PassTyping { } // Helper function to construct initial error - fn construct_main_error(ctx: &Ctx, poly_data: &ExtraData, poly_var_idx: u32, expr: &Expression) -> ParseError { + fn construct_main_error(ctx: &Ctx, poly_data: &PolyData, poly_var_idx: u32, expr: &Expression) -> ParseError { match expr { Expression::Call(expr) => { let (poly_var, func_name) = get_poly_var_and_definition_name(ctx, poly_var_idx, poly_data.definition_id); @@ -3996,7 +4074,7 @@ impl PassTyping { // - check return type with itself if let Some((poly_idx, section_a, section_b)) = has_poly_mismatch( - &poly_data.returned, &poly_data.returned + &poly_data.expr_types.returned, &poly_data.expr_types.returned ) { return construct_main_error(ctx, poly_data, poly_idx, expr) .with_info_at_span( @@ -4010,8 +4088,8 @@ impl PassTyping { } // - check arguments with each other argument and with return type - for (arg_a_idx, arg_a) in poly_data.embedded.iter().enumerate() { - for (arg_b_idx, arg_b) in poly_data.embedded.iter().enumerate() { + for (arg_a_idx, arg_a) in poly_data.expr_types.associated.iter().enumerate() { + for (arg_b_idx, arg_b) in poly_data.expr_types.associated.iter().enumerate() { if arg_b_idx > arg_a_idx { break; } @@ -4047,7 +4125,7 @@ impl PassTyping { } // Check with return type - if let Some((poly_idx, section_arg, section_ret)) = has_poly_mismatch(arg_a, &poly_data.returned) { + if let Some((poly_idx, section_arg, section_ret)) = has_poly_mismatch(arg_a, &poly_data.expr_types.returned) { let arg = &ctx.heap[expr_args[arg_a_idx]]; return construct_main_error(ctx, poly_data, poly_idx, expr) .with_info_at_span( @@ -4068,7 +4146,7 @@ impl PassTyping { // Now check against the explicitly specified polymorphic variables (if // any). 
- for (arg_idx, arg) in poly_data.embedded.iter().enumerate() { + for (arg_idx, arg) in poly_data.expr_types.associated.iter().enumerate() { if let Some((poly_idx, poly_section, arg_section)) = has_explicit_poly_mismatch(&poly_data.poly_vars, arg) { let arg = &ctx.heap[expr_args[arg_idx]]; return construct_main_error(ctx, poly_data, poly_idx, expr) @@ -4082,7 +4160,7 @@ impl PassTyping { } } - if let Some((poly_idx, poly_section, ret_section)) = has_explicit_poly_mismatch(&poly_data.poly_vars, &poly_data.returned) { + if let Some((poly_idx, poly_section, ret_section)) = has_explicit_poly_mismatch(&poly_data.poly_vars, &poly_data.expr_types.returned) { return construct_main_error(ctx, poly_data, poly_idx, expr) .with_info_at_span( &ctx.module().source, expr.full_span(), format!( @@ -4098,6 +4176,21 @@ impl PassTyping { } } +fn get_tuple_size_from_inference_type(inference_type: &InferenceType) -> Result, ()> { + for part in &inference_type.parts { + if part.is_marker() { continue; } + if !part.is_concrete() { break; } + + if let InferenceTypePart::Tuple(size) = part { + return Ok(Some(*size)); + } else { + return Err(()); // not a tuple! + } + } + + return Ok(None); +} + #[cfg(test)] mod tests { use super::*; diff --git a/src/protocol/parser/pass_validation_linking.rs b/src/protocol/parser/pass_validation_linking.rs index dd755059ac902761a6752d00a6b036909d572d6a..697c6770543d5178b575178b9d5f84d83ac7556e 100644 --- a/src/protocol/parser/pass_validation_linking.rs +++ b/src/protocol/parser/pass_validation_linking.rs @@ -42,37 +42,18 @@ use crate::protocol::parser::symbol_table::*; use crate::protocol::parser::type_table::*; use super::visitor::{ - BUFFER_INIT_CAPACITY, + BUFFER_INIT_CAP_SMALL, + BUFFER_INIT_CAP_LARGE, Ctx, Visitor, VisitorResult }; use crate::protocol::parser::ModuleCompilationPhase; -#[derive(PartialEq, Eq)] -enum DefinitionType { - Primitive(ComponentDefinitionId), - Composite(ComponentDefinitionId), - Function(FunctionDefinitionId) -} - -impl DefinitionType { - fn is_primitive(&self) -> bool { if let Self::Primitive(_) = self { true } else { false } } - fn is_composite(&self) -> bool { if let Self::Composite(_) = self { true } else { false } } - fn is_function(&self) -> bool { if let Self::Function(_) = self { true } else { false } } - fn definition_id(&self) -> DefinitionId { - match self { - DefinitionType::Primitive(v) => v.upcast(), - DefinitionType::Composite(v) => v.upcast(), - DefinitionType::Function(v) => v.upcast(), - } - } -} - struct ControlFlowStatement { in_sync: SynchronousStatementId, in_while: WhileStatementId, - in_scope: Scope, + in_scope: ScopeId, statement: StatementId, // of 'break', 'continue' or 'goto' } @@ -101,8 +82,9 @@ pub(crate) struct PassValidationLinking { in_binding_expr_lhs: bool, // Traversal state, current scope (which can be used to find the parent // scope) and the definition variant we are considering. - cur_scope: Scope, - def_type: DefinitionType, + cur_scope: ScopeId, + proc_id: ProcedureDefinitionId, + proc_kind: ProcedureKind, // "Trailing" traversal state, set be child/prev stmt/expr used by next one prev_stmt: StatementId, expr_parent: ExpressionParent, @@ -111,8 +93,7 @@ pub(crate) struct PassValidationLinking { // used for the error's position must_be_assignable: Option, // Keeping track of relative positions and unique IDs. 
- relative_pos_in_block: i32, // of statements: to determine when variables are visible - next_expr_index: i32, // to arrive at a unique ID for all expressions within a definition + relative_pos_in_parent: i32, // of statements: to determine when variables are visible // Control flow statements that require label resolving control_flow_stmts: Vec, // Various temporary buffers for traversal. Essentially working around @@ -122,6 +103,7 @@ pub(crate) struct PassValidationLinking { definition_buffer: ScopedBuffer, statement_buffer: ScopedBuffer, expression_buffer: ScopedBuffer, + scope_buffer: ScopedBuffer, } impl PassValidationLinking { @@ -134,18 +116,19 @@ impl PassValidationLinking { in_test_expr: StatementId::new_invalid(), in_binding_expr: BindingExpressionId::new_invalid(), in_binding_expr_lhs: false, - cur_scope: Scope::new_invalid(), + cur_scope: ScopeId::new_invalid(), prev_stmt: StatementId::new_invalid(), expr_parent: ExpressionParent::None, - def_type: DefinitionType::Function(FunctionDefinitionId::new_invalid()), + proc_id: ProcedureDefinitionId::new_invalid(), + proc_kind: ProcedureKind::Function, must_be_assignable: None, - relative_pos_in_block: 0, - next_expr_index: 0, - control_flow_stmts: Vec::with_capacity(32), - variable_buffer: ScopedBuffer::with_capacity(128), - definition_buffer: ScopedBuffer::with_capacity(128), - statement_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAPACITY), - expression_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAPACITY), + relative_pos_in_parent: 0, + control_flow_stmts: Vec::with_capacity(BUFFER_INIT_CAP_SMALL), + variable_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + definition_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + statement_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + expression_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + scope_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), } } @@ -156,13 +139,13 @@ impl PassValidationLinking { self.in_test_expr = StatementId::new_invalid(); self.in_binding_expr = BindingExpressionId::new_invalid(); self.in_binding_expr_lhs = false; - self.cur_scope = Scope::new_invalid(); - self.def_type = DefinitionType::Function(FunctionDefinitionId::new_invalid()); + self.cur_scope = ScopeId::new_invalid(); + self.proc_id = ProcedureDefinitionId::new_invalid(); + self.proc_kind = ProcedureKind::Function; self.prev_stmt = StatementId::new_invalid(); self.expr_parent = ExpressionParent::None; self.must_be_assignable = None; - self.relative_pos_in_block = 0; - self.next_expr_index = 0; + self.relative_pos_in_parent = 0; self.control_flow_stmts.clear(); } } @@ -203,65 +186,31 @@ impl Visitor for PassValidationLinking { // Definition visitors //-------------------------------------------------------------------------- - fn visit_component_definition(&mut self, ctx: &mut Ctx, id: ComponentDefinitionId) -> VisitorResult { + fn visit_procedure_definition(&mut self, ctx: &mut Ctx, id: ProcedureDefinitionId) -> VisitorResult { self.reset_state(); - self.def_type = match &ctx.heap[id].variant { - ComponentVariant::Primitive => DefinitionType::Primitive(id), - ComponentVariant::Composite => DefinitionType::Composite(id), - }; - self.cur_scope = Scope::Definition(id.upcast()); - self.expr_parent = ExpressionParent::None; - - // Visit parameters and assign a unique scope ID let definition = &ctx.heap[id]; - let body_id = definition.body; - let section = self.variable_buffer.start_section_initialized(&definition.parameters); - for 
variable_idx in 0..section.len() { - let variable_id = section[variable_idx]; - let variable = &mut ctx.heap[variable_id]; - variable.unique_id_in_scope = variable_idx as i32; - } - section.forget(); - - // Visit statements in component body - self.visit_block_stmt(ctx, body_id)?; - - // Assign total number of expressions and assign an in-block unique ID - // to each of the locals in the procedure. - ctx.heap[id].num_expressions_in_body = self.next_expr_index; - self.visit_definition_and_assign_local_ids(ctx, id.upcast()); - self.resolve_pending_control_flow_targets(ctx)?; - - Ok(()) - } - - fn visit_function_definition(&mut self, ctx: &mut Ctx, id: FunctionDefinitionId) -> VisitorResult { - self.reset_state(); - - // Set internal statement indices - self.def_type = DefinitionType::Function(id); - self.cur_scope = Scope::Definition(id.upcast()); + self.proc_id = id; + self.proc_kind = definition.kind; self.expr_parent = ExpressionParent::None; - // Visit parameters and assign a unique scope ID + // Visit parameters + let scope_id = definition.scope; + let old_scope = self.push_scope(ctx, true, scope_id); + let definition = &ctx.heap[id]; let body_id = definition.body; let section = self.variable_buffer.start_section_initialized(&definition.parameters); for variable_idx in 0..section.len() { let variable_id = section[variable_idx]; - let variable = &mut ctx.heap[variable_id]; - variable.unique_id_in_scope = variable_idx as i32; + self.checked_at_single_scope_add_local(ctx, self.cur_scope, -1, variable_id)?; } section.forget(); // Visit statements in function body self.visit_block_stmt(ctx, body_id)?; + self.pop_scope(old_scope); - // Assign total number of expressions and assign an in-block unique ID - // to each of the locals in the procedure. - ctx.heap[id].num_expressions_in_body = self.next_expr_index; - self.visit_definition_and_assign_local_ids(ctx, id.upcast()); self.resolve_pending_control_flow_targets(ctx)?; Ok(()) @@ -272,27 +221,25 @@ impl Visitor for PassValidationLinking { //-------------------------------------------------------------------------- fn visit_block_stmt(&mut self, ctx: &mut Ctx, id: BlockStatementId) -> VisitorResult { - let old_scope = self.push_statement_scope(ctx, Scope::Regular(id)); - - // Set end of block + // Get end of block let block_stmt = &ctx.heap[id]; let end_block_id = block_stmt.end_block; - - // Copy statement IDs into buffer + let scope_id = block_stmt.scope; // Traverse statements in block let statement_section = self.statement_buffer.start_section_initialized(&block_stmt.statements); + let old_scope = self.push_scope(ctx, false, scope_id); assign_and_replace_next_stmt!(self, ctx, id.upcast()); for stmt_idx in 0..statement_section.len() { - self.relative_pos_in_block = stmt_idx as i32; + self.relative_pos_in_parent = stmt_idx as i32; self.visit_stmt(ctx, statement_section[stmt_idx])?; } statement_section.forget(); assign_and_replace_next_stmt!(self, ctx, end_block_id.upcast()); - self.pop_statement_scope(old_scope); + self.pop_scope(old_scope); Ok(()) } @@ -301,7 +248,7 @@ impl Visitor for PassValidationLinking { let expr_id = stmt.initial_expr; let variable_id = stmt.variable; - self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_block, variable_id)?; + self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_parent, variable_id)?; assign_and_replace_next_stmt!(self, ctx, id.upcast().upcast()); debug_assert_eq!(self.expr_parent, ExpressionParent::None); @@ -317,8 +264,8 @@ impl Visitor for PassValidationLinking { let 
from_id = stmt.from; let to_id = stmt.to; - self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_block, from_id)?; - self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_block, to_id)?; + self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_parent, from_id)?; + self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_parent, to_id)?; assign_and_replace_next_stmt!(self, ctx, id.upcast().upcast()); Ok(()) @@ -328,7 +275,7 @@ impl Visitor for PassValidationLinking { let stmt = &ctx.heap[id]; let body_id = stmt.body; - self.checked_add_label(ctx, self.relative_pos_in_block, self.in_sync, id)?; + self.checked_add_label(ctx, self.relative_pos_in_parent, self.in_sync, id)?; self.visit_stmt(ctx, body_id)?; Ok(()) @@ -338,8 +285,8 @@ impl Visitor for PassValidationLinking { let if_stmt = &ctx.heap[id]; let end_if_id = if_stmt.end_if; let test_expr_id = if_stmt.test; - let true_stmt_id = if_stmt.true_body; - let false_stmt_id = if_stmt.false_body; + let true_case = if_stmt.true_case; + let false_case = if_stmt.false_case; // Visit test expression debug_assert_eq!(self.expr_parent, ExpressionParent::None); @@ -356,11 +303,15 @@ impl Visitor for PassValidationLinking { // test expression, not on if-statement itself. Hence the if statement // does not have a static subsequent statement. assign_then_erase_next_stmt!(self, ctx, id.upcast()); - self.visit_block_stmt(ctx, true_stmt_id)?; + let old_scope = self.push_scope(ctx, false, true_case.scope); + self.visit_stmt(ctx, true_case.body)?; + self.pop_scope(old_scope); assign_then_erase_next_stmt!(self, ctx, end_if_id.upcast()); - if let Some(false_id) = false_stmt_id { - self.visit_block_stmt(ctx, false_id)?; + if let Some(false_case) = false_case { + let old_scope = self.push_scope(ctx, false, false_case.scope); + self.visit_stmt(ctx, false_case.body)?; + self.pop_scope(old_scope); assign_then_erase_next_stmt!(self, ctx, end_if_id.upcast()); } @@ -373,6 +324,7 @@ impl Visitor for PassValidationLinking { let end_while_id = stmt.end_while; let test_expr_id = stmt.test; let body_stmt_id = stmt.body; + let scope_id = stmt.scope; let old_while = self.in_while; self.in_while = id; @@ -389,7 +341,9 @@ impl Visitor for PassValidationLinking { assign_then_erase_next_stmt!(self, ctx, id.upcast()); self.expr_parent = ExpressionParent::None; - self.visit_block_stmt(ctx, body_stmt_id)?; + let old_scope = self.push_scope(ctx, false, scope_id); + self.visit_stmt(ctx, body_stmt_id)?; + self.pop_scope(old_scope); self.in_while = old_while; // Link final entry in while's block statement back to the while. 
The @@ -430,6 +384,8 @@ impl Visitor for PassValidationLinking { let sync_stmt = &ctx.heap[id]; let end_sync_id = sync_stmt.end_sync; let cur_sync_span = sync_stmt.span; + let scope_id = sync_stmt.scope; + if !self.in_sync.is_invalid() { // Nested synchronous statement let old_sync_span = ctx.heap[self.in_sync].span; @@ -440,7 +396,7 @@ impl Visitor for PassValidationLinking { )); } - if !self.def_type.is_primitive() { + if self.proc_kind != ProcedureKind::Primitive { return Err(ParseError::new_error_str_at_span( &ctx.module().source, cur_sync_span, "synchronous statements may only be used in primitive components" @@ -456,9 +412,9 @@ impl Visitor for PassValidationLinking { let sync_body = ctx.heap[id].body; debug_assert!(self.in_sync.is_invalid()); self.in_sync = id; - let old_scope = self.push_statement_scope(ctx, Scope::Synchronous(id, sync_body)); - self.visit_block_stmt(ctx, sync_body)?; - self.pop_statement_scope(old_scope); + let old_scope = self.push_scope(ctx, false, scope_id); + self.visit_stmt(ctx, sync_body)?; + self.pop_scope(old_scope); assign_and_replace_next_stmt!(self, ctx, end_sync_id.upcast()); self.in_sync = SynchronousStatementId::new_invalid(); @@ -484,11 +440,11 @@ impl Visitor for PassValidationLinking { // does not have a single static subsequent statement. It forks and then // each fork has a different next statement. assign_then_erase_next_stmt!(self, ctx, id.upcast()); - self.visit_block_stmt(ctx, left_body_id)?; + self.visit_stmt(ctx, left_body_id)?; assign_then_erase_next_stmt!(self, ctx, end_fork_id.upcast()); if let Some(right_body_id) = right_body_id { - self.visit_block_stmt(ctx, right_body_id)?; + self.visit_stmt(ctx, right_body_id)?; assign_then_erase_next_stmt!(self, ctx, end_fork_id.upcast()); } @@ -497,6 +453,10 @@ impl Visitor for PassValidationLinking { } fn visit_select_stmt(&mut self, ctx: &mut Ctx, id: SelectStatementId) -> VisitorResult { + let select_stmt = &mut ctx.heap[id]; + select_stmt.relative_pos_in_parent = self.relative_pos_in_parent; + self.relative_pos_in_parent += 1; + let select_stmt = &ctx.heap[id]; let end_select_id = select_stmt.end_select; @@ -508,7 +468,7 @@ impl Visitor for PassValidationLinking { )); } - if !self.def_type.is_primitive() { + if self.proc_kind != ProcedureKind::Primitive { return Err(ParseError::new_error_str_at_span( &ctx.module().source, select_stmt.span, "select statements may only be used in primitive components" @@ -517,38 +477,39 @@ impl Visitor for PassValidationLinking { // Visit the various arms in the select block let mut case_stmt_ids = self.statement_buffer.start_section(); + let mut case_scope_ids = self.scope_buffer.start_section(); let num_cases = select_stmt.cases.len(); for case in &select_stmt.cases { - // Note: we add both to the buffer, retrieve them later in indexed - // fashion + // We add them in pairs, so the subsequent for-loop retrieves in pairs case_stmt_ids.push(case.guard); - case_stmt_ids.push(case.block.upcast()); + case_stmt_ids.push(case.body); + case_scope_ids.push(case.scope); } assign_then_erase_next_stmt!(self, ctx, id.upcast()); - for idx in 0..num_cases { - let base_idx = 2 * idx; - let guard_id = case_stmt_ids[base_idx ]; - let arm_block_id = case_stmt_ids[base_idx + 1]; - debug_assert_eq!(ctx.heap[arm_block_id].as_block().this.upcast(), arm_block_id); // backwards way of saying arm_block_id is a BlockStatementId - let arm_block_id = BlockStatementId(arm_block_id); + for index in 0..num_cases { + let base_index = 2 * index; + let guard_id = case_stmt_ids[base_index]; + 
let case_body_id = case_stmt_ids[base_index + 1]; + let case_scope_id = case_scope_ids[index]; // The guard statement ends up belonging to the block statement // following the arm. The reason we parse it separately is to // extract all of the "get" calls. - let old_scope = self.push_statement_scope(ctx, Scope::Regular(arm_block_id)); + let old_scope = self.push_scope(ctx, false, case_scope_id); // Visit the guard of this arm debug_assert!(self.in_select_guard.is_invalid()); self.in_select_guard = id; - self.in_select_arm = idx as u32; + self.in_select_arm = index as u32; self.visit_stmt(ctx, guard_id)?; self.in_select_guard = SelectStatementId::new_invalid(); // Visit the code associated with the guard - self.visit_block_stmt(ctx, arm_block_id)?; - self.pop_statement_scope(old_scope); + self.relative_pos_in_parent += 1; + self.visit_stmt(ctx, case_body_id)?; + self.pop_scope(old_scope); // Link up last statement in block to EndSelect assign_then_erase_next_stmt!(self, ctx, end_select_id.upcast()); @@ -562,7 +523,7 @@ impl Visitor for PassValidationLinking { fn visit_return_stmt(&mut self, ctx: &mut Ctx, id: ReturnStatementId) -> VisitorResult { // Check if "return" occurs within a function let stmt = &ctx.heap[id]; - if !self.def_type.is_function() { + if self.proc_kind != ProcedureKind::Function { return Err(ParseError::new_error_str_at_span( &ctx.module().source, stmt.span, "return statements may only appear in function bodies" @@ -594,7 +555,7 @@ impl Visitor for PassValidationLinking { fn visit_new_stmt(&mut self, ctx: &mut Ctx, id: NewStatementId) -> VisitorResult { // Make sure the new statement occurs inside a composite component - if !self.def_type.is_composite() { + if self.proc_kind != ProcedureKind::Composite { let new_stmt = &ctx.heap[id]; return Err(ParseError::new_error_str_at_span( &ctx.module().source, new_stmt.span, @@ -656,8 +617,6 @@ impl Visitor for PassValidationLinking { let right_expr_id = assignment_expr.right; let old_expr_parent = self.expr_parent; assignment_expr.parent = old_expr_parent; - assignment_expr.unique_id_in_definition = self.next_expr_index; - self.next_expr_index += 1; self.expr_parent = ExpressionParent::Expression(upcast_id, 0); self.must_be_assignable = Some(assignment_expr.operator_span); @@ -746,8 +705,6 @@ impl Visitor for PassValidationLinking { let old_expr_parent = self.expr_parent; binding_expr.parent = old_expr_parent; - binding_expr.unique_id_in_definition = self.next_expr_index; - self.next_expr_index += 1; self.in_binding_expr = id; // Perform preliminary check on children: binding expressions only make @@ -801,8 +758,6 @@ impl Visitor for PassValidationLinking { let old_expr_parent = self.expr_parent; conditional_expr.parent = old_expr_parent; - conditional_expr.unique_id_in_definition = self.next_expr_index; - self.next_expr_index += 1; self.expr_parent = ExpressionParent::Expression(upcast_id, 0); self.visit_expr(ctx, test_expr_id)?; @@ -830,8 +785,6 @@ impl Visitor for PassValidationLinking { let old_expr_parent = self.expr_parent; binary_expr.parent = old_expr_parent; - binary_expr.unique_id_in_definition = self.next_expr_index; - self.next_expr_index += 1; self.expr_parent = ExpressionParent::Expression(upcast_id, 0); self.visit_expr(ctx, left_expr_id)?; @@ -854,8 +807,6 @@ impl Visitor for PassValidationLinking { let old_expr_parent = self.expr_parent; unary_expr.parent = old_expr_parent; - unary_expr.unique_id_in_definition = self.next_expr_index; - self.next_expr_index += 1; self.expr_parent = 
ExpressionParent::Expression(id.upcast(), 0); self.visit_expr(ctx, expr_id)?; @@ -873,8 +824,6 @@ impl Visitor for PassValidationLinking { let old_expr_parent = self.expr_parent; indexing_expr.parent = old_expr_parent; - indexing_expr.unique_id_in_definition = self.next_expr_index; - self.next_expr_index += 1; self.expr_parent = ExpressionParent::Expression(upcast_id, 0); self.visit_expr(ctx, subject_expr_id)?; @@ -906,8 +855,6 @@ impl Visitor for PassValidationLinking { let old_expr_parent = self.expr_parent; slicing_expr.parent = old_expr_parent; - slicing_expr.unique_id_in_definition = self.next_expr_index; - self.next_expr_index += 1; self.expr_parent = ExpressionParent::Expression(upcast_id, 0); self.visit_expr(ctx, subject_expr_id)?; @@ -930,8 +877,6 @@ impl Visitor for PassValidationLinking { let old_expr_parent = self.expr_parent; select_expr.parent = old_expr_parent; - select_expr.unique_id_in_definition = self.next_expr_index; - self.next_expr_index += 1; self.expr_parent = ExpressionParent::Expression(id.upcast(), 0); self.visit_expr(ctx, expr_id)?; @@ -944,8 +889,6 @@ impl Visitor for PassValidationLinking { let literal_expr = &mut ctx.heap[id]; let old_expr_parent = self.expr_parent; literal_expr.parent = old_expr_parent; - literal_expr.unique_id_in_definition = self.next_expr_index; - self.next_expr_index += 1; if let Some(span) = self.must_be_assignable { return Err(ParseError::new_error_str_at_span( @@ -1143,8 +1086,6 @@ impl Visitor for PassValidationLinking { let upcast_id = id.upcast(); let old_expr_parent = self.expr_parent; cast_expr.parent = old_expr_parent; - cast_expr.unique_id_in_definition = self.next_expr_index; - self.next_expr_index += 1; // Recurse into the thing that we're casting self.expr_parent = ExpressionParent::Expression(upcast_id, 0); @@ -1203,7 +1144,7 @@ impl Visitor for PassValidationLinking { Method::Assert => { expecting_wrapping_sync_stmt = true; expecting_no_select_stmt = true; - if self.def_type.is_function() { + if self.proc_kind == ProcedureKind::Function { let call_span = call_expr.func_span; return Err(ParseError::new_error_str_at_span( &ctx.module().source, call_span, @@ -1212,7 +1153,10 @@ impl Visitor for PassValidationLinking { } }, Method::Print => {}, - Method::UserFunction => {}, + Method::SelectStart + | Method::SelectRegisterCasePort + | Method::SelectWait => unreachable!(), // not usable by programmer directly + Method::UserFunction => {} Method::UserComponent => { expecting_wrapping_new_stmt = true; }, @@ -1227,7 +1171,7 @@ impl Visitor for PassValidationLinking { return (span, name); } if expecting_primitive_def { - if !self.def_type.is_primitive() { + if self.proc_kind != ProcedureKind::Primitive { let (call_span, func_name) = get_span_and_name(ctx, id); return Err(ParseError::new_error_at_span( &ctx.module().source, call_span, @@ -1275,11 +1219,10 @@ impl Visitor for PassValidationLinking { } // Check the number of arguments - let call_definition = ctx.types.get_base_definition(&call_expr.definition).unwrap(); + let call_definition = ctx.types.get_base_definition(&call_expr.procedure.upcast()).unwrap(); let num_expected_args = match &call_definition.definition { - DefinedTypeVariant::Function(definition) => definition.arguments.len(), - DefinedTypeVariant::Component(definition) => definition.arguments.len(), - v => unreachable!("encountered {} type in call expression", v.type_class()), + DefinedTypeVariant::Procedure(definition) => definition.arguments.len(), + _ => unreachable!(), }; let num_provided_args = 
call_expr.arguments.len(); @@ -1300,8 +1243,6 @@ impl Visitor for PassValidationLinking { let section = self.expression_buffer.start_section_initialized(&call_expr.arguments); let old_expr_parent = self.expr_parent; call_expr.parent = old_expr_parent; - call_expr.unique_id_in_definition = self.next_expr_index; - self.next_expr_index += 1; for arg_expr_idx in 0..section.len() { let arg_expr_id = section[arg_expr_idx]; @@ -1326,7 +1267,7 @@ impl Visitor for PassValidationLinking { // Otherwise try to find it if variable_id.is_none() { - variable_id = self.find_variable(ctx, self.relative_pos_in_block, &var_expr.identifier); + variable_id = self.find_variable(ctx, self.relative_pos_in_parent, &var_expr.identifier); } // Otherwise try to see if is a variable introduced by a binding expr @@ -1355,12 +1296,12 @@ impl Visitor for PassValidationLinking { Expression::Literal(lit_expr) => { // Only struct, unions, tuples and arrays can // have subexpressions, so we're always fine - if cfg!(debug_assertions) { + dbg_code!({ match lit_expr.value { Literal::Struct(_) | Literal::Union(_) | Literal::Array(_) | Literal::Tuple(_) => {}, _ => unreachable!(), } - } + }); true }, @@ -1387,8 +1328,8 @@ impl Visitor for PassValidationLinking { // By now we know that this is a valid binding expression. Given // that a binding expression must be nested under an if/while - // statement, we now add the variable to the (implicit) block - // statement following the if/while statement. + // statement, we now add the variable to the scope associated with + // that statement. let bound_identifier = var_expr.identifier.clone(); let bound_variable_id = ctx.heap.alloc_variable(|this| Variable { this, @@ -1401,17 +1342,17 @@ impl Visitor for PassValidationLinking { full_span: bound_identifier.span }, identifier: bound_identifier, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, unique_id_in_scope: -1, }); - let body_stmt_id = match &ctx.heap[self.in_test_expr] { - Statement::If(stmt) => stmt.true_body, - Statement::While(stmt) => stmt.body, + let scope_id = match &ctx.heap[self.in_test_expr] { + Statement::If(stmt) => stmt.true_case.scope, + Statement::While(stmt) => stmt.scope, _ => unreachable!(), }; - let body_scope = Scope::Regular(body_stmt_id); - self.checked_at_single_scope_add_local(ctx, body_scope, -1, bound_variable_id)?; // add at -1 such that first statement can access + + self.checked_at_single_scope_add_local(ctx, scope_id, -1, bound_variable_id)?; // add at -1 such that first statement can find the variable if needed is_binding_target = true; bound_variable_id @@ -1421,8 +1362,6 @@ impl Visitor for PassValidationLinking { var_expr.declaration = Some(variable_id); var_expr.used_as_binding_target = is_binding_target; var_expr.parent = self.expr_parent; - var_expr.unique_id_in_definition = self.next_expr_index; - self.next_expr_index += 1; Ok(()) } @@ -1438,123 +1377,35 @@ impl PassValidationLinking { /// sync statement or select statement's arm) then we won't do anything. /// In all cases the caller must call `pop_statement_scope` with the scope /// and relative scope position returned by this function. 
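The `push_scope`/`pop_scope` pair introduced below follows a save-and-restore discipline: remember the previous scope id and statement position, link the pushed scope to its parent, then restore both values when the scope is left. A minimal, self-contained sketch of that pattern follows; `Scope`, `ScopeId` and `Visitor` here are simplified stand-ins, not the compiler's actual arena types.

```rust
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct ScopeId(usize);

#[derive(Default)]
struct Scope {
    parent: Option<ScopeId>,
    nested: Vec<ScopeId>,
    relative_pos_in_parent: i32,
}

struct Visitor {
    scopes: Vec<Scope>, // arena: ScopeId indexes into this
    cur_scope: Option<ScopeId>,
    relative_pos_in_parent: i32,
}

impl Visitor {
    fn push_scope(&mut self, is_top_level: bool, pushed: ScopeId) -> (Option<ScopeId>, i32) {
        let old_scope = self.cur_scope;
        let old_pos = self.relative_pos_in_parent;

        // Link the pushed scope to its parent (unless it is the procedure's
        // top-level scope) and record where it sits within that parent.
        if !is_top_level {
            if let Some(parent_id) = old_scope {
                self.scopes[pushed.0].parent = Some(parent_id);
                self.scopes[parent_id.0].nested.push(pushed);
            }
        }
        self.scopes[pushed.0].relative_pos_in_parent = old_pos;

        // The nested scope starts counting statement positions from zero.
        self.cur_scope = Some(pushed);
        self.relative_pos_in_parent = 0;
        (old_scope, old_pos)
    }

    fn pop_scope(&mut self, saved: (Option<ScopeId>, i32)) {
        self.cur_scope = saved.0;
        self.relative_pos_in_parent = saved.1;
    }
}

fn main() {
    let mut v = Visitor {
        scopes: vec![Scope::default(), Scope::default()],
        cur_scope: None,
        relative_pos_in_parent: 0,
    };
    let outer = v.push_scope(true, ScopeId(0));
    v.relative_pos_in_parent = 3; // pretend three statements were visited
    let inner = v.push_scope(false, ScopeId(1));
    assert_eq!(v.relative_pos_in_parent, 0); // nested scope counts from zero
    v.pop_scope(inner);
    assert_eq!(v.relative_pos_in_parent, 3); // position in the outer scope restored
    v.pop_scope(outer);
}
```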
- fn push_statement_scope(&mut self, ctx: &mut Ctx, new_scope: Scope) -> (Scope, i32) { - let old_scope = self.cur_scope.clone(); - debug_assert!(new_scope.is_block()); // never call for Definition scope - let is_new_block = if old_scope.is_block() { - old_scope.to_block() != new_scope.to_block() - } else { - true - }; - - if !is_new_block { - // No need to push, but still return old scope, we pretend like we - // replaced it. - debug_assert!(!ctx.heap[new_scope.to_block()].scope_node.parent.is_invalid()); - return (old_scope, self.relative_pos_in_block); - } - - // This is a new block, so link it up - if old_scope.is_block() { - let parent_block = &mut ctx.heap[old_scope.to_block()]; - parent_block.scope_node.nested.push(new_scope); + fn push_scope(&mut self, ctx: &mut Ctx, is_top_level_scope: bool, pushed_scope_id: ScopeId) -> (ScopeId, i32) { + // Set the properties of the pushed scope (it is already created during + // AST construction, but most values are not yet set to their correct + // values) + let old_scope_id = self.cur_scope; + + let scope = &mut ctx.heap[pushed_scope_id]; + if !is_top_level_scope { + scope.parent = Some(old_scope_id); } - self.cur_scope = new_scope; - - let cur_block = &mut ctx.heap[new_scope.to_block()]; - cur_block.scope_node.parent = old_scope; - cur_block.scope_node.relative_pos_in_parent = self.relative_pos_in_block; - - let old_relative_pos = self.relative_pos_in_block; - self.relative_pos_in_block = -1; + scope.relative_pos_in_parent = self.relative_pos_in_parent; + let old_relative_pos = self.relative_pos_in_parent; + self.relative_pos_in_parent = 0; - return (old_scope, old_relative_pos) - } - - fn pop_statement_scope(&mut self, scope_to_restore: (Scope, i32)) { - self.cur_scope = scope_to_restore.0; - self.relative_pos_in_block = scope_to_restore.1; - } - - fn visit_definition_and_assign_local_ids(&mut self, ctx: &mut Ctx, definition_id: DefinitionId) { - let mut var_counter = 0; - - // Set IDs on parameters - let (param_section, body_id) = match &ctx.heap[definition_id] { - Definition::Function(func_def) => ( - self.variable_buffer.start_section_initialized(&func_def.parameters), - func_def.body - ), - Definition::Component(comp_def) => ( - self.variable_buffer.start_section_initialized(&comp_def.parameters), - comp_def.body - ), - _ => unreachable!(), - } ; - - for var_id in param_section.iter_copied() { - let var = &mut ctx.heap[var_id]; - var.unique_id_in_scope = var_counter; - var_counter += 1; + // Link up scopes + if !is_top_level_scope { + let old_scope = &mut ctx.heap[old_scope_id]; + old_scope.nested.push(pushed_scope_id); } - param_section.forget(); - - // Recurse into body - self.visit_block_and_assign_local_ids(ctx, body_id, var_counter); + // Set as current traversal scope, then return old scope + self.cur_scope = pushed_scope_id; + return (old_scope_id, old_relative_pos) } - fn visit_block_and_assign_local_ids(&mut self, ctx: &mut Ctx, block_id: BlockStatementId, mut var_counter: i32) { - let block_stmt = &mut ctx.heap[block_id]; - block_stmt.first_unique_id_in_scope = var_counter; - - let var_section = self.variable_buffer.start_section_initialized(&block_stmt.locals); - let mut scope_section = self.statement_buffer.start_section(); - for child_scope in &block_stmt.scope_node.nested { - debug_assert!(child_scope.is_block(), "found a child scope that is not a block statement"); - scope_section.push(child_scope.to_block().upcast()); - } - - let mut var_idx = 0; - let mut scope_idx = 0; - while var_idx < var_section.len() || scope_idx < 
scope_section.len() { - let relative_var_pos = if var_idx < var_section.len() { - ctx.heap[var_section[var_idx]].relative_pos_in_block - } else { - i32::MAX - }; - - let relative_scope_pos = if scope_idx < scope_section.len() { - ctx.heap[scope_section[scope_idx]].as_block().scope_node.relative_pos_in_parent - } else { - i32::MAX - }; - - debug_assert!(!(relative_var_pos == i32::MAX && relative_scope_pos == i32::MAX)); - - // In certain cases the relative variable position is the same as - // the scope position (insertion of binding variables). In that case - // the variable should be treated first - if relative_var_pos <= relative_scope_pos { - let var = &mut ctx.heap[var_section[var_idx]]; - var.unique_id_in_scope = var_counter; - var_counter += 1; - var_idx += 1; - } else { - // Boy oh boy - let block_id = ctx.heap[scope_section[scope_idx]].as_block().this; - self.visit_block_and_assign_local_ids(ctx, block_id, var_counter); - scope_idx += 1; - } - } - - var_section.forget(); - scope_section.forget(); - - // Done assigning all IDs, assign the last ID to the block statement scope - let block_stmt = &mut ctx.heap[block_id]; - block_stmt.next_unique_id_in_scope = var_counter; + fn pop_scope(&mut self, scope_to_restore: (ScopeId, i32)) { + self.cur_scope = scope_to_restore.0; + self.relative_pos_in_parent = scope_to_restore.1; } fn resolve_pending_control_flow_targets(&mut self, ctx: &mut Ctx) -> Result<(), ParseError> { @@ -1613,76 +1464,50 @@ impl PassValidationLinking { /// Adds a local variable to the current scope. It will also annotate the /// `Local` in the AST with its relative position in the block. - fn checked_add_local(&mut self, ctx: &mut Ctx, target_scope: Scope, target_relative_pos: i32, id: VariableId) -> Result<(), ParseError> { - debug_assert!(target_scope.is_block()); - let local = &ctx.heap[id]; + fn checked_add_local(&mut self, ctx: &mut Ctx, target_scope_id: ScopeId, target_relative_pos: i32, new_variable_id: VariableId) -> Result<(), ParseError> { + let new_variable = &ctx.heap[new_variable_id]; // We immediately go to the parent scope. We check the target scope // in the call at the end. That is also where we check for collisions // with symbols. - let block = &ctx.heap[target_scope.to_block()]; - let mut scope = block.scope_node.parent; - let mut cur_relative_pos = block.scope_node.relative_pos_in_parent; - loop { - if let Scope::Definition(definition_id) = scope { - // At outer scope, check parameters of function/component - for parameter_id in ctx.heap[definition_id].parameters() { - let parameter = &ctx.heap[*parameter_id]; - if local.identifier == parameter.identifier { - return Err( - ParseError::new_error_str_at_span( - &ctx.module().source, local.identifier.span, "Local variable name conflicts with parameter" - ).with_info_str_at_span( - &ctx.module().source, parameter.identifier.span, "Parameter definition is found here" - ) - ); - } - } - - // No collisions - break; - } - - // If here then the parent scope is a block scope - let block = &ctx.heap[scope.to_block()]; - - for other_local_id in &block.locals { - let other_local = &ctx.heap[*other_local_id]; - // Position check in case another variable with the same name - // is defined in a higher-level scope, but later than the scope - // in which the current variable resides. 
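The collision check in `checked_add_local` walks the parent chain and only reports a conflict when an enclosing scope declares a variable with the same name that is already visible at the point where the nested scope starts (`cur_relative_pos >= existing.relative_pos_in_parent`). A small sketch of just that parent-chain rule, with hypothetical simplified types in place of the real heap and identifiers (the real pass additionally checks the target scope itself and the symbol table):

```rust
struct Var { name: String, relative_pos_in_parent: i32 }

struct Scope {
    parent: Option<usize>,          // index of the parent scope, if any
    relative_pos_in_parent: i32,    // where this scope appears in its parent
    variables: Vec<Var>,
}

fn conflicts_with_parent_scopes(scopes: &[Scope], target: usize, new_name: &str) -> bool {
    let mut scope = &scopes[target];
    let mut cur_relative_pos = scope.relative_pos_in_parent;
    while let Some(parent_idx) = scope.parent {
        scope = &scopes[parent_idx];
        for existing in &scope.variables {
            // "Visible" means: declared at or before the position where the
            // previously visited (nested) scope appears in this parent.
            if existing.name == new_name && cur_relative_pos >= existing.relative_pos_in_parent {
                return true;
            }
        }
        cur_relative_pos = scope.relative_pos_in_parent;
    }
    false
}

fn main() {
    // scope 0: { let a; /* scope 1 starts here, at position 1 */ }
    let scopes = vec![
        Scope { parent: None, relative_pos_in_parent: 0,
                variables: vec![Var { name: "a".into(), relative_pos_in_parent: 0 }] },
        Scope { parent: Some(0), relative_pos_in_parent: 1, variables: vec![] },
    ];
    assert!(conflicts_with_parent_scopes(&scopes, 1, "a"));  // `a` is already visible
    assert!(!conflicts_with_parent_scopes(&scopes, 1, "b")); // `b` is not
}
```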
- if local.this != *other_local_id && - cur_relative_pos >= other_local.relative_pos_in_block && - local.identifier == other_local.identifier { - // Collision within this scope + let mut scope = &ctx.heap[target_scope_id]; + let mut cur_relative_pos = scope.relative_pos_in_parent; + while let Some(scope_parent_id) = scope.parent { + scope = &ctx.heap[scope_parent_id]; + + // Check for collisions + for variable_id in scope.variables.iter().copied() { + let existing_variable = &ctx.heap[variable_id]; + if existing_variable.identifier == new_variable.identifier && + existing_variable.this != new_variable_id && + cur_relative_pos >= existing_variable.relative_pos_in_parent { return Err( ParseError::new_error_str_at_span( - &ctx.module().source, local.identifier.span, "Local variable name conflicts with another variable" + &ctx.module().source, new_variable.identifier.span, "Local variable name conflicts with another variable" ).with_info_str_at_span( - &ctx.module().source, other_local.identifier.span, "Previous variable is found here" + &ctx.module().source, existing_variable.identifier.span, "Previous variable is found here" ) ); } } - scope = block.scope_node.parent; - cur_relative_pos = block.scope_node.relative_pos_in_parent; + cur_relative_pos = scope.relative_pos_in_parent; } // No collisions in any of the parent scope, attempt to add to scope - self.checked_at_single_scope_add_local(ctx, target_scope, target_relative_pos, id) + self.checked_at_single_scope_add_local(ctx, target_scope_id, target_relative_pos, new_variable_id) } /// Adds a local variable to the specified scope. Will check the specified /// scope for variable conflicts and the symbol table for global conflicts. /// Will NOT check parent scopes of the specified scope. fn checked_at_single_scope_add_local( - &mut self, ctx: &mut Ctx, scope: Scope, relative_pos: i32, id: VariableId + &mut self, ctx: &mut Ctx, scope_id: ScopeId, relative_pos: i32, new_variable_id: VariableId ) -> Result<(), ParseError> { // Check the symbol table for conflicts { - let cur_scope = SymbolScope::Definition(self.def_type.definition_id()); - let ident = &ctx.heap[id].identifier; + let cur_scope = SymbolScope::Definition(self.proc_id.upcast()); + let ident = &ctx.heap[new_variable_id].identifier; if let Some(symbol) = ctx.symbols.get_symbol_by_name(cur_scope, &ident.value.as_bytes()) { return Err(ParseError::new_error_str_at_span( &ctx.module().source, ident.span, @@ -1694,32 +1519,31 @@ impl PassValidationLinking { } // Check the specified scope for conflicts - let local = &ctx.heap[id]; + let new_variable = &ctx.heap[new_variable_id]; + let scope = &ctx.heap[scope_id]; - debug_assert!(scope.is_block()); - let block = &ctx.heap[scope.to_block()]; - for other_local_id in &block.locals { - let other_local = &ctx.heap[*other_local_id]; - if local.this != other_local.this && + for variable_id in scope.variables.iter().copied() { + let old_variable = &ctx.heap[variable_id]; + if new_variable.this != old_variable.this && // relative_pos >= other_local.relative_pos_in_block && - local.identifier == other_local.identifier { + new_variable.identifier == old_variable.identifier { // Collision return Err( ParseError::new_error_str_at_span( - &ctx.module().source, local.identifier.span, "Local variable name conflicts with another variable" + &ctx.module().source, new_variable.identifier.span, "Local variable name conflicts with another variable" ).with_info_str_at_span( - &ctx.module().source, other_local.identifier.span, "Previous variable is found here" + 
&ctx.module().source, old_variable.identifier.span, "Previous variable is found here" ) ); } } // No collisions - let block = &mut ctx.heap[scope.to_block()]; - block.locals.push(id); + let scope = &mut ctx.heap[scope_id]; + scope.variables.push(new_variable_id); - let local = &mut ctx.heap[id]; - local.relative_pos_in_block = relative_pos; + let variable = &mut ctx.heap[new_variable_id]; + variable.relative_pos_in_parent = relative_pos; Ok(()) } @@ -1727,85 +1551,66 @@ impl PassValidationLinking { /// Finds a variable in the visitor's scope that must appear before the /// specified relative position within that block. fn find_variable(&self, ctx: &Ctx, mut relative_pos: i32, identifier: &Identifier) -> Option { - debug_assert!(self.cur_scope.is_block()); + let mut scope_id = self.cur_scope; - // No need to use iterator over namespaces if here - let mut scope = &self.cur_scope; - loop { - debug_assert!(scope.is_block()); - let block = &ctx.heap[scope.to_block()]; + // Check if we can find the variable in the current scope + let scope = &ctx.heap[scope_id]; - for local_id in &block.locals { - let local = &ctx.heap[*local_id]; + for variable_id in scope.variables.iter().copied() { + let variable = &ctx.heap[variable_id]; - if local.relative_pos_in_block < relative_pos && identifier == &local.identifier { - return Some(*local_id); + if variable.relative_pos_in_parent < relative_pos && identifier == &variable.identifier { + return Some(variable_id); } } - scope = &block.scope_node.parent; - if !scope.is_block() { - // Definition scope, need to check arguments to definition - match scope { - Scope::Definition(definition_id) => { - let definition = &ctx.heap[*definition_id]; - for parameter_id in definition.parameters() { - let parameter = &ctx.heap[*parameter_id]; - if identifier == ¶meter.identifier { - return Some(*parameter_id); - } - } - }, - _ => unreachable!(), - } - - // Variable could not be found - return None - } else { - relative_pos = block.scope_node.relative_pos_in_parent; + // Could not find variable, move to parent scope and try again + if scope.parent.is_none() { + return None; } + + scope_id = scope.parent.unwrap(); + relative_pos = scope.relative_pos_in_parent; } } /// Adds a particular label to the current scope. Will return an error if /// there is another label with the same name visible in the current scope. - fn checked_add_label(&mut self, ctx: &mut Ctx, relative_pos: i32, in_sync: SynchronousStatementId, id: LabeledStatementId) -> Result<(), ParseError> { - debug_assert!(self.cur_scope.is_block()); - + fn checked_add_label(&mut self, ctx: &mut Ctx, relative_pos: i32, in_sync: SynchronousStatementId, new_label_id: LabeledStatementId) -> Result<(), ParseError> { // Make sure label is not defined within the current scope or any of the // parent scope. 
- let label = &mut ctx.heap[id]; - label.relative_pos_in_block = relative_pos; - label.in_sync = in_sync; + let new_label = &mut ctx.heap[new_label_id]; + new_label.relative_pos_in_parent = relative_pos; + new_label.in_sync = in_sync; - let label = &ctx.heap[id]; - let mut scope = &self.cur_scope; + let new_label = &ctx.heap[new_label_id]; + let mut scope_id = self.cur_scope; loop { - debug_assert!(scope.is_block(), "scope is not a block"); - let block = &ctx.heap[scope.to_block()]; - for other_label_id in &block.labels { - let other_label = &ctx.heap[*other_label_id]; - if other_label.label == label.label { + let scope = &ctx.heap[scope_id]; + for existing_label_id in scope.labels.iter().copied() { + let existing_label = &ctx.heap[existing_label_id]; + if existing_label.label == new_label.label { // Collision return Err(ParseError::new_error_str_at_span( - &ctx.module().source, label.label.span, "label name is used more than once" + &ctx.module().source, new_label.label.span, "label name is used more than once" ).with_info_str_at_span( - &ctx.module().source, other_label.label.span, "the other label is found here" + &ctx.module().source, existing_label.label.span, "the other label is found here" )); } } - scope = &block.scope_node.parent; - if !scope.is_block() { + if scope.parent.is_none() { break; } + + scope_id = scope.parent.unwrap(); } // No collisions - let block = &mut ctx.heap[self.cur_scope.to_block()]; - block.labels.push(id); + let scope = &mut ctx.heap[self.cur_scope]; + scope.labels.push(new_label_id); Ok(()) } @@ -1813,60 +1618,57 @@ impl PassValidationLinking { /// Finds a particular labeled statement by its identifier. Once found it /// will make sure that the target label does not skip over any variable /// declarations within the scope in which the label was found. - fn find_label(mut scope: Scope, ctx: &Ctx, identifier: &Identifier) -> Result { - debug_assert!(scope.is_block()); - + fn find_label(mut scope_id: ScopeId, ctx: &Ctx, identifier: &Identifier) -> Result { loop { - debug_assert!(scope.is_block(), "scope is not a block"); - let relative_scope_pos = ctx.heap[scope.to_block()].scope_node.relative_pos_in_parent; + let scope = &ctx.heap[scope_id]; + let relative_scope_pos = scope.relative_pos_in_parent; - let block = &ctx.heap[scope.to_block()]; - for label_id in &block.labels { - let label = &ctx.heap[*label_id]; + for label_id in scope.labels.iter().copied() { + let label = &ctx.heap[label_id]; if label.label == *identifier { - for local_id in &block.locals { + // Found the target label, now make sure that the jump to + // the label doesn't imply a skipped variable declaration + for variable_id in scope.variables.iter().copied() { // TODO: Better to do this in control flow analysis, it // is legal to skip over a variable declaration if it - // is not actually being used. I might be missing - // something here when laying out the bytecode... - let local = &ctx.heap[*local_id]; - if local.relative_pos_in_block > relative_scope_pos && local.relative_pos_in_block < label.relative_pos_in_block { + // is not actually being used. 
+ let variable = &ctx.heap[variable_id]; + if variable.relative_pos_in_parent > relative_scope_pos && variable.relative_pos_in_parent < label.relative_pos_in_parent { return Err( ParseError::new_error_str_at_span(&ctx.module().source, identifier.span, "this target label skips over a variable declaration") .with_info_str_at_span(&ctx.module().source, label.label.span, "because it jumps to this label") - .with_info_str_at_span(&ctx.module().source, local.identifier.span, "which skips over this variable") + .with_info_str_at_span(&ctx.module().source, variable.identifier.span, "which skips over this variable") ); } } - return Ok(*label_id); + return Ok(label_id); } } - scope = block.scope_node.parent; - if !scope.is_block() { + if scope.parent.is_none() { return Err(ParseError::new_error_str_at_span( &ctx.module().source, identifier.span, "could not find this label" )); } + scope_id = scope.parent.unwrap(); } } - /// This function will check if the provided while statement ID has a block - /// statement that is one of our current parents. - fn has_parent_while_scope(mut scope: Scope, ctx: &Ctx, id: WhileStatementId) -> bool { - let while_stmt = &ctx.heap[id]; + /// This function will check if the provided scope has a parent that belongs + /// to a while statement. + fn scope_is_nested_in_while_statement(mut scope_id: ScopeId, ctx: &Ctx, expected_while_id: WhileStatementId) -> bool { + let while_stmt = &ctx.heap[expected_while_id]; + loop { - debug_assert!(scope.is_block()); - let block = scope.to_block(); - if while_stmt.body == block { + let scope = &ctx.heap[scope_id]; + if scope.this == while_stmt.scope { return true; } - let block = &ctx.heap[block]; - scope = block.scope_node.parent; - if !scope.is_block() { - return false; + match scope.parent { + Some(new_scope_id) => scope_id = new_scope_id, + None => return false, // walked all the way up, not encountering the while statement } } } @@ -1885,9 +1687,10 @@ impl PassValidationLinking { // Make sure break target is a while statement let target = &ctx.heap[target_id]; if let Statement::While(target_stmt) = &ctx.heap[target.body] { - // Even though we have a target while statement, the break might not be - // present underneath this particular labeled while statement - if !Self::has_parent_while_scope(control_flow.in_scope, ctx, target_stmt.this) { + // Even though we have a target while statement, the control + // flow statement might not be present underneath this + // particular labeled while statement. + if !Self::scope_is_nested_in_while_statement(control_flow.in_scope, ctx, target_stmt.this) { return Err(ParseError::new_error_str_at_span( &ctx.module().source, label.span, "break statement is not nested under the target label's while statement" ).with_info_str_at_span( diff --git a/src/protocol/parser/type_table.rs b/src/protocol/parser/type_table.rs index de0d8cc5461dc9dee9fa87b537774f81a014d8e8..916ebf1f17a37369911fd8167247ac17c60a032f 100644 --- a/src/protocol/parser/type_table.rs +++ b/src/protocol/parser/type_table.rs @@ -36,8 +36,10 @@ * layout. */ -use std::fmt::{Formatter, Result as FmtResult}; +// Programmer note: deduplication of types is currently disabled, see the +// @Deduplication key. Tests might fail when it is re-enabled. 
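As a reminder of what the currently disabled deduplication would buy: when a polymorphic variable is declared but never contributes to a type's layout, distinct instantiations share the same memory layout and could map to a single monomorph entry. A tiny illustration in plain Rust (not the protocol language), assuming a hypothetical type with an unused parameter:

```rust
use std::marker::PhantomData;
use std::mem::size_of;

// `B` is declared but does not affect the layout, so Foo<u32> and Foo<u64>
// have identical size and alignment and could share one monomorph.
struct Foo<B> {
    field: u32,
    _marker: PhantomData<B>,
}

fn main() {
    assert_eq!(size_of::<Foo<u32>>(), size_of::<Foo<u64>>());
}
```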
use std::collections::HashMap; +use std::hash::{Hash, Hasher}; use crate::protocol::ast::*; use crate::protocol::parser::symbol_table::SymbolScope; @@ -48,40 +50,6 @@ use crate::protocol::parser::*; // Defined Types //------------------------------------------------------------------------------ -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum TypeClass { - Enum, - Union, - Struct, - Function, - Component -} - -impl TypeClass { - pub(crate) fn display_name(&self) -> &'static str { - match self { - TypeClass::Enum => "enum", - TypeClass::Union => "union", - TypeClass::Struct => "struct", - TypeClass::Function => "function", - TypeClass::Component => "component", - } - } - - pub(crate) fn is_data_type(&self) -> bool { - match self { - TypeClass::Enum | TypeClass::Union | TypeClass::Struct => true, - TypeClass::Function | TypeClass::Component => false, - } - } -} - -impl std::fmt::Display for TypeClass { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - write!(f, "{}", self.display_name()) - } -} - /// Struct wrapping around a potentially polymorphic type. If the type does not /// have any polymorphic arguments then it will not have any monomorphs and /// `is_polymorph` will be set to `false`. A type with polymorphic arguments @@ -101,46 +69,44 @@ pub enum DefinedTypeVariant { Enum(EnumType), Union(UnionType), Struct(StructType), - Function(FunctionType), - Component(ComponentType) + Procedure(ProcedureType), } impl DefinedTypeVariant { - pub(crate) fn type_class(&self) -> TypeClass { + pub(crate) fn is_data_type(&self) -> bool { + use DefinedTypeVariant as DTV; + match self { - DefinedTypeVariant::Enum(_) => TypeClass::Enum, - DefinedTypeVariant::Union(_) => TypeClass::Union, - DefinedTypeVariant::Struct(_) => TypeClass::Struct, - DefinedTypeVariant::Function(_) => TypeClass::Function, - DefinedTypeVariant::Component(_) => TypeClass::Component + DTV::Struct(_) | DTV::Enum(_) | DTV::Union(_) => return true, + DTV::Procedure(_) => return false, } } pub(crate) fn as_struct(&self) -> &StructType { match self { DefinedTypeVariant::Struct(v) => v, - _ => unreachable!("Cannot convert {} to struct variant", self.type_class()) + _ => unreachable!() } } pub(crate) fn as_enum(&self) -> &EnumType { match self { DefinedTypeVariant::Enum(v) => v, - _ => unreachable!("Cannot convert {} to enum variant", self.type_class()) + _ => unreachable!() } } pub(crate) fn as_union(&self) -> &UnionType { match self { DefinedTypeVariant::Union(v) => v, - _ => unreachable!("Cannot convert {} to union variant", self.type_class()) + _ => unreachable!() } } } pub struct PolymorphicVariable { - identifier: Identifier, - is_in_use: bool, // a polymorphic argument may be defined, but not used by the type definition + pub(crate) identifier: Identifier, + pub(crate) is_in_use: bool, // a polymorphic argument may be defined, but not used by the type definition } /// `EnumType` is the classical C/C++ enum type. It has various variants with @@ -196,19 +162,14 @@ pub struct StructField { pub parser_type: ParserType, } -/// `FunctionType` is what you expect it to be: a particular function's -/// signature. 
-pub struct FunctionType { - pub return_types: Vec, - pub arguments: Vec, +/// `ProcedureType` is the signature of a procedure/component +pub struct ProcedureType { + pub kind: ProcedureKind, + pub return_type: Option, + pub arguments: Vec, } -pub struct ComponentType { - pub variant: ComponentVariant, - pub arguments: Vec, -} - -pub struct FunctionArgument { +pub struct ProcedureArgument { identifier: Identifier, parser_type: ParserType, } @@ -224,22 +185,15 @@ pub struct MonomorphExpression { // monomorph index for polymorphic function calls or literals. Negative // values are never used, but used to catch programming errors. pub(crate) field_or_monomorph_idx: i32, + pub(crate) type_id: TypeId, } //------------------------------------------------------------------------------ // Type monomorph storage //------------------------------------------------------------------------------ -/// Generic monomorph has a specific concrete type, a size and an alignment. -/// Extra data is in the `MonomorphVariant` per kind of type. -pub(crate) struct TypeMonomorph { - pub concrete_type: ConcreteType, - pub size: usize, - pub alignment: usize, - pub variant: MonomorphVariant, -} - -pub(crate) enum MonomorphVariant { +pub(crate) enum MonoTypeVariant { + Builtin, // no extra data, added manually in compiler initialization code Enum, // no extra data Struct(StructMonomorph), Union(UnionMonomorph), @@ -247,45 +201,45 @@ pub(crate) enum MonomorphVariant { Tuple(TupleMonomorph), } -impl MonomorphVariant { +impl MonoTypeVariant { fn as_struct_mut(&mut self) -> &mut StructMonomorph { match self { - MonomorphVariant::Struct(v) => v, + MonoTypeVariant::Struct(v) => v, _ => unreachable!(), } } pub(crate) fn as_union(&self) -> &UnionMonomorph { match self { - MonomorphVariant::Union(v) => v, + MonoTypeVariant::Union(v) => v, _ => unreachable!(), } } fn as_union_mut(&mut self) -> &mut UnionMonomorph { match self { - MonomorphVariant::Union(v) => v, + MonoTypeVariant::Union(v) => v, _ => unreachable!(), } } fn as_tuple_mut(&mut self) -> &mut TupleMonomorph { match self { - MonomorphVariant::Tuple(v) => v, + MonoTypeVariant::Tuple(v) => v, _ => unreachable!(), } } - fn as_procedure(&self) -> &ProcedureMonomorph { + pub(crate) fn as_procedure(&self) -> &ProcedureMonomorph { match self { - MonomorphVariant::Procedure(v) => v, + MonoTypeVariant::Procedure(v) => v, _ => unreachable!(), } } fn as_procedure_mut(&mut self) -> &mut ProcedureMonomorph { match self { - MonomorphVariant::Procedure(v) => v, + MonoTypeVariant::Procedure(v) => v, _ => unreachable!(), } } @@ -297,7 +251,8 @@ pub struct StructMonomorph { } pub struct StructMonomorphField { - pub concrete_type: ConcreteType, + pub type_id: TypeId, + concrete_type: ConcreteType, pub size: usize, pub alignment: usize, pub offset: usize, @@ -323,7 +278,8 @@ pub struct UnionMonomorphVariant { } pub struct UnionMonomorphEmbedded { - pub concrete_type: ConcreteType, + pub type_id: TypeId, + concrete_type: ConcreteType, // Note that the meaning of the offset (and alignment) depend on whether or // not the variant lives on the stack/heap. If it lives on the stack then // they refer to the offset from the start of the union value (so the first @@ -338,9 +294,8 @@ pub struct UnionMonomorphEmbedded { /// Procedure (functions and components of all possible types) monomorph. Also /// stores the expression type data from the typechecking/inferencing pass. 
pub struct ProcedureMonomorph { - // Expression data for one particular monomorph - pub arg_types: Vec, - pub expr_data: Vec, + pub monomorph_index: u32, + pub builtin: bool, } /// Tuple monomorph. Again a kind of exception because one cannot define a named @@ -351,197 +306,238 @@ pub struct TupleMonomorph { } pub struct TupleMonomorphMember { - pub concrete_type: ConcreteType, + pub type_id: TypeId, + concrete_type: ConcreteType, pub size: usize, pub alignment: usize, pub offset: usize, } -/// Key used to perform lookups in the monomorph table. It computes a hash of -/// the type while not taking the unused polymorphic variables of the base type -/// into account (e.g. `struct Foo{ A field }`, here `B` is an unused -/// polymorphic variable). -struct MonomorphKey { - parts: Vec, - in_use: Vec, // TODO: @Performance, limit num args and use two `u64` as bitflags or something +/// Generic unique type ID. Every monomorphed type and every non-polymorphic +/// type will have one of these associated with it. +#[derive(Debug, Clone, Copy, PartialEq)] +pub struct TypeId(i64); + +impl TypeId { + pub(crate) fn new_invalid() -> Self { + return Self(-1); + } } -use std::hash::*; +/// A monomorphed type (or non-polymorphic type's) memory layout and information +/// regarding associated types (like a struct's field type). +pub struct MonoType { + pub type_id: TypeId, + pub concrete_type: ConcreteType, + pub size: usize, + pub alignment: usize, + pub(crate) variant: MonoTypeVariant +} -impl Hash for MonomorphKey { - fn hash(&self, state: &mut H) { - // if `in_use` is empty, then we may assume the type is not polymorphic - // (or all types are in use) - if self.in_use.is_empty() { - self.parts.hash(state); +impl MonoType { + #[inline] + fn new_empty(type_id: TypeId, concrete_type: ConcreteType, variant: MonoTypeVariant) -> Self { + return Self { + type_id, concrete_type, + size: 0, + alignment: 0, + variant, + } + } + + /// Little internal helper function as a reminder: if alignment is 0, then + /// the size/alignment are not actually computed yet! + #[inline] + fn get_size_alignment(&self) -> Option<(usize, usize)> { + if self.alignment == 0 { + return None } else { - // type is polymorphic - self.parts[0].hash(state); - - // note: hash is computed in a unique way, because practically - // speaking `in_use` is fixed per base type. So we cannot have the - // same base type (hence: a type with the same DefinitionId) with - // different different polymorphic variables in use. - let mut in_use_index = 0; - for section in ConcreteTypeIter::new(self.parts.as_slice(), 0) { - if self.in_use[in_use_index] { - section.hash(state); - } - in_use_index += 1; - } + return Some((self.size, self.alignment)); } } } -impl PartialEq for MonomorphKey { - fn eq(&self, other: &Self) -> bool { - if self.in_use.is_empty() { - let temp_result = self.parts == other.parts; - return temp_result; - } else { - // Outer type does not match - if self.parts[0] != other.parts[0] { - return false; - } +/// Special structure that acts like the lookup key for `ConcreteType` instances +/// that have already been added to the type table before. +#[derive(Clone)] +struct MonoSearchKey { + // Uses bitflags to denote when parts between search keys should match and + // whether they should be checked. Needs to have a system like this to + // accommodate tuples. 
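`MonoSearchKey` tags every type part with an in-use flag so that, once deduplication is re-enabled, hashing and equality can skip the parts contributed by unused polymorphic variables. A simplified, hypothetical sketch of that idea (the real key additionally tracks change bits to delimit subtrees, which matters for tuples):

```rust
use std::collections::HashMap;
use std::hash::{Hash, Hasher};

// Simplified search key: (poly var in use, opaque type part). Hashing and
// comparison only look at the parts whose flag is set, so keys that differ
// solely in unused polymorphic arguments collide on purpose.
#[derive(Eq)]
struct SearchKey {
    parts: Vec<(bool, u64)>,
}

impl Hash for SearchKey {
    fn hash<H: Hasher>(&self, state: &mut H) {
        for (in_use, part) in &self.parts {
            if *in_use { part.hash(state); }
        }
    }
}

impl PartialEq for SearchKey {
    fn eq(&self, other: &Self) -> bool {
        let a = self.parts.iter().filter(|p| p.0).map(|p| p.1);
        let b = other.parts.iter().filter(|p| p.0).map(|p| p.1);
        a.eq(b)
    }
}

fn main() {
    let mut lookup: HashMap<SearchKey, u32> = HashMap::new();
    // Same base type (part 7), second poly var unused: 3 vs 4 must not matter.
    let a = SearchKey { parts: vec![(true, 7), (true, 1), (false, 3)] };
    let b = SearchKey { parts: vec![(true, 7), (true, 1), (false, 4)] };
    lookup.insert(a, 0);
    assert!(lookup.contains_key(&b)); // deduplicated: both map to the same entry
}
```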
+ parts: Vec<(u8, ConcreteTypePart)>, + change_bit: u8, +} - debug_assert_eq!(self.parts[0].num_embedded() as usize, self.in_use.len()); - let mut iter_self = ConcreteTypeIter::new(self.parts.as_slice(), 0); - let mut iter_other = ConcreteTypeIter::new(other.parts.as_slice(), 0); - let mut index = 0; - while let Some(section_self) = iter_self.next() { - let section_other = iter_other.next().unwrap(); - let in_use = self.in_use[index]; - index += 1; +impl MonoSearchKey { + const KEY_IN_USE: u8 = 0x01; + const KEY_CHANGE_BIT: u8 = 0x02; - if !in_use { - continue; - } + fn with_capacity(capacity: usize) -> Self { + return MonoSearchKey{ + parts: Vec::with_capacity(capacity), + change_bit: 0, + }; + } - if section_self != section_other { - return false; - } - } + /// Sets the search key based on a single concrete type and its polymorphic + /// variables. + fn set(&mut self, concrete_type_parts: &[ConcreteTypePart], poly_var_in_use: &[PolymorphicVariable]) { + self.set_top_type(concrete_type_parts[0]); - return true; + let mut poly_var_index = 0; + for subtype in ConcreteTypeIter::new(concrete_type_parts, 0) { + let in_use = poly_var_in_use[poly_var_index].is_in_use; + poly_var_index += 1; + self.push_subtype(subtype, in_use); } - } -} -impl Eq for MonomorphKey {} - -use std::cell::UnsafeCell; + debug_assert_eq!(poly_var_index, poly_var_in_use.len()); + } -/// Lookup table for monomorphs. Wrapped in a special struct because we don't -/// want to allocate for each lookup (what we really want is a HashMap that -/// exposes its CompareFn and HashFn, but whatevs). -pub(crate) struct MonomorphTable { - lookup: HashMap, // indexes into `monomorphs` - pub(crate) monomorphs: Vec, - // We use an UnsafeCell because this is only used internally per call to - // `get_monomorph_index` calls. This is safe because `&TypeMonomorph`s - // retrieved for this class remain valid when the key is mutated and the - // type table is not multithreaded. - // - // I added this because we don't want to allocate for each lookup, hence we - // need a reusable `key` internal to this class. This in turn makes - // `get_monomorph_index` a mutable call. Now the code that calls this - // function (even though we're not mutating the table!) needs a lot of extra - // boilerplate. I opted for the `UnsafeCell` instead of the boilerplate. 
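The removed `UnsafeCell` workaround and the new `mono_search_key` field on the type table exist for the same reason: reusing one key buffer avoids an allocation on every lookup, at the cost of lookups needing a mutable receiver. A minimal sketch of that trade-off, with hypothetical names and a plain `Vec<u64>` standing in for the real search key:

```rust
use std::collections::HashMap;

// The table owns a single reusable key buffer; lookups take &mut self so the
// buffer can be cleared and refilled without allocating per call.
struct Table {
    lookup: HashMap<Vec<u64>, usize>,
    key_buffer: Vec<u64>,
}

impl Table {
    fn new() -> Self {
        Table { lookup: HashMap::new(), key_buffer: Vec::with_capacity(32) }
    }

    fn insert(&mut self, parts: &[u64], value: usize) {
        self.lookup.insert(parts.to_vec(), value);
    }

    // Mutable receiver purely because the key buffer is reused.
    fn get(&mut self, parts: &[u64]) -> Option<usize> {
        self.key_buffer.clear();
        self.key_buffer.extend_from_slice(parts);
        self.lookup.get(&self.key_buffer).copied()
    }
}

fn main() {
    let mut table = Table::new();
    table.insert(&[1, 2, 3], 42);
    assert_eq!(table.get(&[1, 2, 3]), Some(42));
    assert_eq!(table.get(&[4, 5]), None);
}
```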
- key: UnsafeCell, -} + /// Starts setting the search key based on an initial top-level type, + /// programmer must call `push_subtype` the appropriate number of times + /// after calling this function + fn set_top_type(&mut self, type_part: ConcreteTypePart) { + self.parts.clear(); + self.parts.push((Self::KEY_IN_USE, type_part)); + self.change_bit = Self::KEY_CHANGE_BIT; + } -// TODO: Clean this up: somehow prevent the `key`, but also do not allocate for -// each "get_monomorph_index" -unsafe impl Send for MonomorphTable{} -unsafe impl Sync for MonomorphTable{} + fn push_subtype(&mut self, concrete_type: &[ConcreteTypePart], in_use: bool) { + let flag = self.change_bit | (if in_use { Self::KEY_IN_USE } else { 0 }); -impl MonomorphTable { - fn new() -> Self { - return Self { - lookup: HashMap::with_capacity(256), - monomorphs: Vec::with_capacity(256), - key: UnsafeCell::new(MonomorphKey{ - parts: Vec::with_capacity(32), - in_use: Vec::with_capacity(32), - }), + for part in concrete_type { + self.parts.push((flag, *part)); } + self.change_bit ^= Self::KEY_CHANGE_BIT; } - fn insert_with_zero_size_and_alignment(&mut self, concrete_type: ConcreteType, in_use: &[PolymorphicVariable], variant: MonomorphVariant) -> i32 { - let key = MonomorphKey{ - parts: Vec::from(concrete_type.parts.as_slice()), - in_use: in_use.iter().map(|v| v.is_in_use).collect(), - }; - let index = self.monomorphs.len(); - let _result = self.lookup.insert(key, index as i32); - debug_assert!(_result.is_none()); // did not exist yet - self.monomorphs.push(TypeMonomorph{ - concrete_type, - size: 0, - alignment: 0, - variant, - }); + fn push_subtree(&mut self, concrete_type: &[ConcreteTypePart], poly_var_in_use: &[PolymorphicVariable]) { + self.parts.push((self.change_bit | Self::KEY_IN_USE, concrete_type[0])); + self.change_bit ^= Self::KEY_CHANGE_BIT; - return index as i32; + let mut poly_var_index = 0; + for subtype in ConcreteTypeIter::new(concrete_type, 0) { + let in_use = poly_var_in_use[poly_var_index].is_in_use; + poly_var_index += 1; + self.push_subtype(subtype, in_use); + } + + debug_assert_eq!(poly_var_index, poly_var_in_use.len()); } - fn get_monomorph_index(&self, parts: &[ConcreteTypePart], in_use: &[PolymorphicVariable]) -> Option { - let key = unsafe { - // Clear-and-extend to, at some point, prevent future allocations - let key = &mut *self.key.get(); - key.parts.clear(); - key.parts.extend_from_slice(parts); - key.in_use.clear(); - key.in_use.extend(in_use.iter().map(|v| v.is_in_use)); + // Utilities for hashing and comparison + fn find_end_index(&self, start_index: usize) -> usize { + // Check if we're already at the end + let mut index = start_index; + if index >= self.parts.len() { + return index; + } - &*key - }; + // Iterate until bit flips, or until at end + let expected_bit = self.parts[index].0 & Self::KEY_CHANGE_BIT; - match self.lookup.get(key) { - Some(index) => return Some(*index), - None => return None, + index += 1; + while index < self.parts.len() { + let current_bit = self.parts[index].0 & Self::KEY_CHANGE_BIT; + if current_bit != expected_bit { + return index; + } + + index += 1; } - } - #[inline] - fn get(&self, index: i32) -> &TypeMonomorph { - debug_assert!(index >= 0); - return &self.monomorphs[index as usize]; + return index; } +} - #[inline] - fn get_mut(&mut self, index: i32) -> &mut TypeMonomorph { - debug_assert!(index >= 0); - return &mut self.monomorphs[index as usize]; +impl Hash for MonoSearchKey { + fn hash(&self, state: &mut H) { + for index in 0..self.parts.len() { + let 
(_flags, part) = self.parts[index]; + // if flags & Self::KEY_IN_USE != 0 { @Deduplication + part.hash(state); + // } + } } +} - fn get_monomorph_size_alignment(&self, index: i32) -> Option<(usize, usize)> { - let monomorph = self.get(index); - if monomorph.size == 0 && monomorph.alignment == 0 { - // If both are zero, then we wish to mean: we haven't actually - // computed the size and alignment yet. So: - return None; - } else { - return Some((monomorph.size, monomorph.alignment)); +impl PartialEq for MonoSearchKey { + fn eq(&self, other: &Self) -> bool { + let mut self_index = 0; + let mut other_index = 0; + + while self_index < self.parts.len() && other_index < other.parts.len() { + // Retrieve part and flags + let (_self_bits, _) = self.parts[self_index]; + let (_other_bits, _) = other.parts[other_index]; + let self_in_use = true; // (self_bits & Self::KEY_IN_USE) != 0; @Deduplication + let other_in_use = true; // (other_bits & Self::KEY_IN_USE) != 0; @Deduplication + + // Determine ending indices + let self_end_index = self.find_end_index(self_index); + let other_end_index = other.find_end_index(other_index); + + if self_in_use == other_in_use { + if self_in_use { + // Both are in use, so both parts should be equal + let delta_self = self_end_index - self_index; + let delta_other = other_end_index - other_index; + if delta_self != delta_other { + // Both in use, but not of equal length, so the types + // cannot match + return false; + } + + for _ in 0..delta_self { + let (_, self_part) = self.parts[self_index]; + let (_, other_part) = other.parts[other_index]; + + if self_part != other_part { + return false; + } + + self_index += 1; + other_index += 1; + } + } else { + // Both not in use, so skip associated parts + self_index = self_end_index; + other_index = other_end_index; + } + } else { + // No agreement on importance of parts. This is practically + // impossible + unreachable!(); + } } + + // Everything matched, so if we're at the end of both arrays then we're + // certain that the two keys are equal. + return self_index == self.parts.len() && other_index == other.parts.len(); } } +impl Eq for MonoSearchKey{} + //------------------------------------------------------------------------------ // Type table //------------------------------------------------------------------------------ +const POLY_VARS_IN_USE: [PolymorphicVariable; 1] = [PolymorphicVariable{ identifier: Identifier::new_empty(InputSpan::new()), is_in_use: true }]; + // Programmer note: keep this struct free of dynamically allocated memory #[derive(Clone)] struct TypeLoopBreadcrumb { - monomorph_idx: i32, + type_id: TypeId, next_member: u32, next_embedded: u32, // for unions, the index into the variant's embedded types } +// Programmer note: keep this struct free of dynamically allocated memory #[derive(Clone)] struct MemoryBreadcrumb { - monomorph_idx: i32, + type_id: TypeId, next_member: u32, next_embedded: u32, first_size_alignment_idx: u32, @@ -561,27 +557,33 @@ enum MemoryLayoutResult { // TODO: @Optimize, initial memory-unoptimized implementation struct TypeLoopEntry { - monomorph_idx: i32, + type_id: TypeId, is_union: bool, } struct TypeLoop { - members: Vec + members: Vec, } +type DefinitionMap = HashMap; +type MonoTypeMap = HashMap; +type MonoTypeArray = Vec; + pub struct TypeTable { - /// Lookup from AST DefinitionId to a defined type. 
Also lookups for - /// concrete type to monomorphs - pub(crate) type_lookup: HashMap, - pub(crate) mono_lookup: MonomorphTable, - /// Breadcrumbs left behind while trying to find type loops. Also used to - /// determine sizes of types when all type loops are detected. + // Lookup from AST DefinitionId to a defined type. Also lookups for + // concrete type to monomorphs + pub(crate) definition_lookup: DefinitionMap, + mono_type_lookup: MonoTypeMap, + pub(crate) mono_types: MonoTypeArray, + mono_search_key: MonoSearchKey, + // Breadcrumbs left behind while trying to find type loops. Also used to + // determine sizes of types when all type loops are detected. type_loop_breadcrumbs: Vec, type_loops: Vec, - /// Stores all encountered types during type loop detection. Used afterwards - /// to iterate over all types in order to compute size/alignment. + // Stores all encountered types during type loop detection. Used afterwards + // to iterate over all types in order to compute size/alignment. encountered_types: Vec, - /// Breadcrumbs and temporary storage during memory layout computation. + // Breadcrumbs and temporary storage during memory layout computation. memory_layout_breadcrumbs: Vec, size_alignment_stack: Vec<(usize, usize)>, } @@ -590,8 +592,10 @@ impl TypeTable { /// Construct a new type table without any resolved types. pub(crate) fn new() -> Self { Self{ - type_lookup: HashMap::with_capacity(128), - mono_lookup: MonomorphTable::new(), + definition_lookup: HashMap::with_capacity(128), + mono_type_lookup: HashMap::with_capacity(128), + mono_types: Vec::with_capacity(128), + mono_search_key: MonoSearchKey::with_capacity(32), type_loop_breadcrumbs: Vec::with_capacity(32), type_loops: Vec::with_capacity(8), encountered_types: Vec::with_capacity(32), @@ -609,17 +613,17 @@ impl TypeTable { pub(crate) fn build_base_types(&mut self, modules: &mut [Module], ctx: &mut PassCtx) -> Result<(), ParseError> { // Make sure we're allowed to cast root_id to index into ctx.modules debug_assert!(modules.iter().all(|m| m.phase >= ModuleCompilationPhase::DefinitionsParsed)); - debug_assert!(self.type_lookup.is_empty()); + debug_assert!(self.definition_lookup.is_empty()); - if cfg!(debug_assertions) { + dbg_code!({ for (index, module) in modules.iter().enumerate() { debug_assert_eq!(index, module.root_id.index as usize); } - } + }); // Use context to guess hashmap size of the base types let reserve_size = ctx.heap.definitions.len(); - self.type_lookup.reserve(reserve_size); + self.definition_lookup.reserve(reserve_size); // Resolve all base types for definition_idx in 0..ctx.heap.definitions.len() { @@ -630,34 +634,34 @@ impl TypeTable { Definition::Enum(_) => self.build_base_enum_definition(modules, ctx, definition_id)?, Definition::Union(_) => self.build_base_union_definition(modules, ctx, definition_id)?, Definition::Struct(_) => self.build_base_struct_definition(modules, ctx, definition_id)?, - Definition::Function(_) => self.build_base_function_definition(modules, ctx, definition_id)?, - Definition::Component(_) => self.build_base_component_definition(modules, ctx, definition_id)?, + Definition::Procedure(_) => self.build_base_procedure_definition(modules, ctx, definition_id)?, } } - debug_assert_eq!(self.type_lookup.len(), reserve_size, "mismatch in reserved size of type table"); // NOTE: Temp fix for builtin functions + debug_assert_eq!(self.definition_lookup.len(), reserve_size, "mismatch in reserved size of type table"); for module in modules.iter_mut() { module.phase = 
ModuleCompilationPhase::TypesAddedToTable; } // Go through all types again, lay out all types that are not - // polymorphic. This might cause us to lay out types that are monomorphs - // of polymorphic types. + // polymorphic. This might cause us to lay out monomorphized polymorphs + // if these were member types of non-polymorphic types. for definition_idx in 0..ctx.heap.definitions.len() { let definition_id = ctx.heap.definitions.get_id(definition_idx); - let poly_type = self.type_lookup.get(&definition_id).unwrap(); + let poly_type = self.definition_lookup.get(&definition_id).unwrap(); - if !poly_type.definition.type_class().is_data_type() || !poly_type.poly_vars.is_empty() { + if !poly_type.definition.is_data_type() || !poly_type.poly_vars.is_empty() { continue; } // If here then the type is a data type without polymorphic // variables, but we might have instantiated it already, so: let concrete_parts = [ConcreteTypePart::Instance(definition_id, 0)]; - let mono_index = self.mono_lookup.get_monomorph_index(&concrete_parts, &[]); - if mono_index.is_none() { + self.mono_search_key.set(&concrete_parts, &[]); + let type_id = self.mono_type_lookup.get(&self.mono_search_key); + if type_id.is_none() { self.detect_and_resolve_type_loops_for( - modules, ctx.heap, + modules, ctx.heap, ctx.arch, ConcreteType{ parts: vec![ConcreteTypePart::Instance(definition_id, 0)] }, @@ -675,37 +679,26 @@ impl TypeTable { /// an option anyway #[inline] pub(crate) fn get_base_definition(&self, definition_id: &DefinitionId) -> Option<&DefinedType> { - self.type_lookup.get(&definition_id) + self.definition_lookup.get(&definition_id) } /// Returns the index into the monomorph type array if the procedure type /// already has a (reserved) monomorph. - /// FIXME: This really shouldn't be called from within the runtime. See UnsafeCell in MonomorphTable - #[inline] - pub(crate) fn get_procedure_monomorph_index(&self, definition_id: &DefinitionId, type_parts: &[ConcreteTypePart]) -> Option { - let base_type = self.type_lookup.get(definition_id).unwrap(); - return self.mono_lookup.get_monomorph_index(type_parts, &base_type.poly_vars); - } - #[inline] - pub(crate) fn get_monomorph(&self, monomorph_index: i32) -> &TypeMonomorph { - return self.mono_lookup.get(monomorph_index); + pub(crate) fn get_procedure_monomorph_type_id(&self, definition_id: &DefinitionId, type_parts: &[ConcreteTypePart]) -> Option { + // Cannot use internal search key due to mutability issues. But this + // method should end up being deprecated at some point anyway. + debug_assert_eq!(get_concrete_type_definition(type_parts).unwrap(), *definition_id); + let base_type = self.definition_lookup.get(definition_id).unwrap(); + let mut search_key = MonoSearchKey::with_capacity(type_parts.len()); + search_key.set(type_parts, &base_type.poly_vars); + + return self.mono_type_lookup.get(&search_key).copied(); } - /// Returns a mutable reference to a procedure's monomorph expression data. 
- /// Used by typechecker to fill in previously reserved type information #[inline] - pub(crate) fn get_procedure_monomorph_mut(&mut self, monomorph_index: i32) -> &mut ProcedureMonomorph { - debug_assert!(monomorph_index >= 0); - let monomorph = self.mono_lookup.get_mut(monomorph_index); - return monomorph.variant.as_procedure_mut(); - } - - #[inline] - pub(crate) fn get_procedure_monomorph(&self, monomorph_index: i32) -> &ProcedureMonomorph { - debug_assert!(monomorph_index >= 0); - let monomorph = self.mono_lookup.get(monomorph_index); - return monomorph.variant.as_procedure(); + pub(crate) fn get_monomorph(&self, type_id: TypeId) -> &MonoType { + return &self.mono_types[type_id.0 as usize]; } /// Reserves space for a monomorph of a polymorphic procedure. The index @@ -713,40 +706,78 @@ impl TypeTable { /// monomorph may NOT exist yet (because the reservation implies that we're /// going to be performing typechecking on it, and we don't want to /// check the same monomorph twice) - pub(crate) fn reserve_procedure_monomorph_index(&mut self, definition_id: &DefinitionId, concrete_type: ConcreteType) -> i32 { - let base_type = self.type_lookup.get_mut(definition_id).unwrap(); - let mono_index = self.mono_lookup.insert_with_zero_size_and_alignment( - concrete_type, &base_type.poly_vars, MonomorphVariant::Procedure(ProcedureMonomorph{ - arg_types: Vec::new(), - expr_data: Vec::new(), + pub(crate) fn reserve_procedure_monomorph_type_id(&mut self, definition_id: &DefinitionId, concrete_type: ConcreteType, monomorph_index: u32) -> TypeId { + debug_assert_eq!(get_concrete_type_definition(&concrete_type.parts).unwrap(), *definition_id); + let type_id = TypeId(self.mono_types.len() as i64); + let base_type = self.definition_lookup.get_mut(definition_id).unwrap(); + self.mono_search_key.set(&concrete_type.parts, &base_type.poly_vars); + + debug_assert!(!self.mono_type_lookup.contains_key(&self.mono_search_key)); + self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id); + self.mono_types.push(MonoType::new_empty(type_id, concrete_type, MonoTypeVariant::Procedure(ProcedureMonomorph{ + monomorph_index, + builtin: false, + }))); + + return type_id; + } + + /// Adds a builtin type to the type table. As this is only called by the + /// compiler during setup we assume it cannot fail. + pub(crate) fn add_builtin_data_type(&mut self, concrete_type: ConcreteType, poly_vars: &[PolymorphicVariable], size: usize, alignment: usize) -> TypeId { + self.mono_search_key.set(&concrete_type.parts, poly_vars); + debug_assert!(!self.mono_type_lookup.contains_key(&self.mono_search_key)); + debug_assert_ne!(alignment, 0); + let type_id = TypeId(self.mono_types.len() as i64); + self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id); + self.mono_types.push(MonoType{ + type_id, + concrete_type, + size, + alignment, + variant: MonoTypeVariant::Builtin, + }); + + return type_id; + } + + /// Adds a builtin procedure to the type table. 
+ pub(crate) fn add_builtin_procedure_type(&mut self, concrete_type: ConcreteType, poly_vars: &[PolymorphicVariable]) -> TypeId { + self.mono_search_key.set(&concrete_type.parts, poly_vars); + debug_assert!(!self.mono_type_lookup.contains_key(&self.mono_search_key)); + let type_id = TypeId(self.mono_types.len() as i64); + self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id); + self.mono_types.push(MonoType{ + type_id, + concrete_type, + size: 0, + alignment: 0, + variant: MonoTypeVariant::Procedure(ProcedureMonomorph{ + monomorph_index: u32::MAX, + builtin: true, }) - ); + }); - return mono_index; + return type_id; } - /// Adds a datatype polymorph to the type table. Will not add the - /// monomorph if it is already present, or if the type's polymorphic - /// variables are all unused. - /// TODO: Fix signature - pub(crate) fn add_data_monomorph( - &mut self, modules: &[Module], heap: &Heap, arch: &TargetArch, definition_id: DefinitionId, concrete_type: ConcreteType - ) -> Result { - debug_assert_eq!(definition_id, get_concrete_type_definition(&concrete_type)); - - // Check if the monomorph already exists - let poly_type = self.type_lookup.get_mut(&definition_id).unwrap(); - if let Some(idx) = self.mono_lookup.get_monomorph_index(&concrete_type.parts, &poly_type.poly_vars) { - return Ok(idx); + /// Adds a monomorphed type to the type table. If it already exists then the + /// previous entry will be used. + pub(crate) fn add_monomorphed_type( + &mut self, modules: &[Module], heap: &Heap, arch: &TargetArch, concrete_type: ConcreteType + ) -> Result { + // Check if the concrete type was already added + Self::set_search_key_to_type(&mut self.mono_search_key, &self.definition_lookup, &concrete_type.parts); + if let Some(type_id) = self.mono_type_lookup.get(&self.mono_search_key) { + return Ok(*type_id); } - // Doesn't exist, so instantiate a monomorph and determine its memory - // layout. - self.detect_and_resolve_type_loops_for(modules, heap, concrete_type)?; - let mono_idx = self.encountered_types[0].monomorph_idx; + // Concrete type needs to be added + self.detect_and_resolve_type_loops_for(modules, heap, arch, concrete_type)?; + let type_id = self.encountered_types[0].type_id; self.lay_out_memory_for_encountered_types(arch); - return Ok(mono_idx as i32); + return Ok(type_id); } //-------------------------------------------------------------------------- @@ -755,7 +786,7 @@ impl TypeTable { /// Builds the base type for an enum. Will not compute byte sizes fn build_base_enum_definition(&mut self, modules: &[Module], ctx: &mut PassCtx, definition_id: DefinitionId) -> Result<(), ParseError> { - debug_assert!(!self.type_lookup.contains_key(&definition_id), "base enum already built"); + debug_assert!(!self.definition_lookup.contains_key(&definition_id), "base enum already built"); let definition = ctx.heap[definition_id].as_enum(); let root_id = definition.defined_in; @@ -807,7 +838,7 @@ impl TypeTable { Self::check_poly_args_collision(modules, ctx, root_id, &definition.poly_vars)?; let poly_vars = Self::create_polymorphic_variables(&definition.poly_vars); - self.type_lookup.insert(definition_id, DefinedType { + self.definition_lookup.insert(definition_id, DefinedType { ast_root: root_id, ast_definition: definition_id, definition: DefinedTypeVariant::Enum(EnumType{ @@ -827,7 +858,7 @@ impl TypeTable { /// Builds the base type for a union. Will compute byte sizes. 
fn build_base_union_definition(&mut self, modules: &[Module], ctx: &mut PassCtx, definition_id: DefinitionId) -> Result<(), ParseError> { - debug_assert!(!self.type_lookup.contains_key(&definition_id), "base union already built"); + debug_assert!(!self.definition_lookup.contains_key(&definition_id), "base union already built"); let definition = ctx.heap[definition_id].as_union(); let root_id = definition.defined_in; @@ -872,7 +903,7 @@ impl TypeTable { let is_polymorph = poly_vars.iter().any(|arg| arg.is_in_use); - self.type_lookup.insert(definition_id, DefinedType{ + self.definition_lookup.insert(definition_id, DefinedType{ ast_root: root_id, ast_definition: definition_id, definition: DefinedTypeVariant::Union(UnionType{ variants, tag_type, tag_size }), @@ -885,7 +916,7 @@ impl TypeTable { /// Builds base struct type. Will not compute byte sizes. fn build_base_struct_definition(&mut self, modules: &[Module], ctx: &mut PassCtx, definition_id: DefinitionId) -> Result<(), ParseError> { - debug_assert!(!self.type_lookup.contains_key(&definition_id), "base struct already built"); + debug_assert!(!self.definition_lookup.contains_key(&definition_id), "base struct already built"); let definition = ctx.heap[definition_id].as_struct(); let root_id = definition.defined_in; @@ -917,7 +948,7 @@ impl TypeTable { let is_polymorph = poly_vars.iter().any(|arg| arg.is_in_use); - self.type_lookup.insert(definition_id, DefinedType{ + self.definition_lookup.insert(definition_id, DefinedType{ ast_root: root_id, ast_definition: definition_id, definition: DefinedTypeVariant::Struct(StructType{ fields }), @@ -928,15 +959,14 @@ impl TypeTable { return Ok(()) } - /// Builds base function type. - fn build_base_function_definition(&mut self, modules: &[Module], ctx: &mut PassCtx, definition_id: DefinitionId) -> Result<(), ParseError> { - debug_assert!(!self.type_lookup.contains_key(&definition_id), "base function already built"); - let definition = ctx.heap[definition_id].as_function(); + /// Builds base procedure type. + fn build_base_procedure_definition(&mut self, modules: &[Module], ctx: &mut PassCtx, definition_id: DefinitionId) -> Result<(), ParseError> { + debug_assert!(!self.definition_lookup.contains_key(&definition_id), "base function already built"); + let definition = ctx.heap[definition_id].as_procedure(); let root_id = definition.defined_in; // Check and construct return types and argument types. - debug_assert_eq!(definition.return_types.len(), 1, "not one return type"); - for return_type in &definition.return_types { + if let Some(return_type) = &definition.return_type { Self::check_member_parser_type( modules, ctx, root_id, return_type, definition.builtin )?; @@ -949,7 +979,7 @@ impl TypeTable { modules, ctx, root_id, ¶meter.parser_type, definition.builtin )?; - arguments.push(FunctionArgument{ + arguments.push(ProcedureArgument{ identifier: parameter.identifier.clone(), parser_type: parameter.parser_type.clone(), }); @@ -957,82 +987,36 @@ impl TypeTable { // Check conflict of identifiers Self::check_identifier_collision( - modules, root_id, &arguments, |arg| &arg.identifier, "function argument" + modules, root_id, &arguments, |arg| &arg.identifier, "procedure argument" )?; Self::check_poly_args_collision(modules, ctx, root_id, &definition.poly_vars)?; // Construct internal representation of function type + // TODO: Marking used polymorphic variables should take statements in + // the body into account. But currently we don't. Hence mark them all + // as being in-use. 
Note to self: true condition should be that the + // polymorphic variables are used in places where the resulting types + // are themselves truly polymorphic types (e.g. not a phantom type). let mut poly_vars = Self::create_polymorphic_variables(&definition.poly_vars); - for return_type in &definition.return_types { - Self::mark_used_polymorphic_variables(&mut poly_vars, return_type); - } - for argument in &arguments { - Self::mark_used_polymorphic_variables(&mut poly_vars, &argument.parser_type); - } - - let is_polymorph = poly_vars.iter().any(|arg| arg.is_in_use); - - self.type_lookup.insert(definition_id, DefinedType{ - ast_root: root_id, - ast_definition: definition_id, - definition: DefinedTypeVariant::Function(FunctionType{ return_types: definition.return_types.clone(), arguments }), - poly_vars, - is_polymorph - }); - - return Ok(()); - } - - /// Builds base component type. - fn build_base_component_definition(&mut self, modules: &[Module], ctx: &mut PassCtx, definition_id: DefinitionId) -> Result<(), ParseError> { - debug_assert!(!self.type_lookup.contains_key(&definition_id), "base component already built"); - - let definition = &ctx.heap[definition_id].as_component(); - let root_id = definition.defined_in; - - // Check the argument types - let mut arguments = Vec::with_capacity(definition.parameters.len()); - for parameter_id in &definition.parameters { - let parameter = &ctx.heap[*parameter_id]; - Self::check_member_parser_type( - modules, ctx, root_id, ¶meter.parser_type, false - )?; - - arguments.push(FunctionArgument{ - identifier: parameter.identifier.clone(), - parser_type: parameter.parser_type.clone(), - }); - } - - // Check conflict of identifiers - Self::check_identifier_collision( - modules, root_id, &arguments, |arg| &arg.identifier, "connector argument" - )?; - Self::check_poly_args_collision(modules, ctx, root_id, &definition.poly_vars)?; - - // Construct internal representation of component - // TODO: Marking used polymorphic variables on procedures requires - // making sure that each is used in the body. For now, mark them all - // as required. - let mut poly_vars = Self::create_polymorphic_variables(&definition.poly_vars); - // for argument in &arguments { - // Self::mark_used_polymorphic_variables(&mut poly_vars, &argument.parser_type); - // } for poly_var in &mut poly_vars { poly_var.is_in_use = true; } let is_polymorph = poly_vars.iter().any(|arg| arg.is_in_use); - self.type_lookup.insert(definition_id, DefinedType{ + self.definition_lookup.insert(definition_id, DefinedType{ ast_root: root_id, ast_definition: definition_id, - definition: DefinedTypeVariant::Component(ComponentType{ variant: definition.variant, arguments }), + definition: DefinedTypeVariant::Procedure(ProcedureType{ + kind: definition.kind, + return_type: definition.return_type.clone(), + arguments + }), poly_vars, is_polymorph }); - Ok(()) + return Ok(()); } /// Will check if the member type (field of a struct, embedded type in a @@ -1155,7 +1139,7 @@ impl TypeTable { /// Internal function that will detect type loops and check if they're /// resolvable. If so then the appropriate union variants will be marked as /// "living on heap". 
If not then a `ParseError` will be returned - fn detect_and_resolve_type_loops_for(&mut self, modules: &[Module], heap: &Heap, concrete_type: ConcreteType) -> Result<(), ParseError> { + fn detect_and_resolve_type_loops_for(&mut self, modules: &[Module], heap: &Heap, arch: &TargetArch, concrete_type: ConcreteType) -> Result<(), ParseError> { // Programmer notes: what happens here is the we call // `check_member_for_type_loops` for a particular type's member, and // then take action using the return value: @@ -1182,12 +1166,16 @@ impl TypeTable { debug_assert!(self.encountered_types.is_empty()); // Push the initial breadcrumb - let initial_breadcrumb = self.check_member_for_type_loops(&concrete_type); + let initial_breadcrumb = Self::check_member_for_type_loops( + &self.type_loop_breadcrumbs, &self.definition_lookup, &self.mono_type_lookup, + &mut self.mono_search_key, &concrete_type + ); + if let TypeLoopResult::PushBreadcrumb(definition_id, concrete_type) = initial_breadcrumb { - self.handle_new_breadcrumb_for_type_loops(definition_id, concrete_type); + self.handle_new_breadcrumb_for_type_loops(arch, definition_id, concrete_type); } else { - unreachable!(); - } + unreachable!() + }; // Enter into the main resolving loop while !self.type_loop_breadcrumbs.is_empty() { @@ -1195,12 +1183,15 @@ impl TypeTable { let breadcrumb_idx = self.type_loop_breadcrumbs.len() - 1; let mut breadcrumb = self.type_loop_breadcrumbs[breadcrumb_idx].clone(); - let monomorph = self.mono_lookup.get(breadcrumb.monomorph_idx); - let resolve_result = match &monomorph.variant { - MonomorphVariant::Enum => { + let mono_type = &self.mono_types[breadcrumb.type_id.0 as usize]; + let resolve_result = match &mono_type.variant { + MonoTypeVariant::Builtin => { + TypeLoopResult::TypeExists + } + MonoTypeVariant::Enum => { TypeLoopResult::TypeExists }, - MonomorphVariant::Union(monomorph) => { + MonoTypeVariant::Union(monomorph) => { let num_variants = monomorph.variants.len() as u32; let mut union_result = TypeLoopResult::TypeExists; @@ -1210,7 +1201,10 @@ impl TypeTable { while breadcrumb.next_embedded < num_embedded { let mono_embedded = &mono_variant.embedded[breadcrumb.next_embedded as usize]; - union_result = self.check_member_for_type_loops(&mono_embedded.concrete_type); + union_result = Self::check_member_for_type_loops( + &self.type_loop_breadcrumbs, &self.definition_lookup, &self.mono_type_lookup, + &mut self.mono_search_key, &mono_embedded.concrete_type + ); if union_result != TypeLoopResult::TypeExists { // In type loop or new breadcrumb pushed, so @@ -1227,13 +1221,16 @@ impl TypeTable { union_result }, - MonomorphVariant::Struct(monomorph) => { + MonoTypeVariant::Struct(monomorph) => { let num_fields = monomorph.fields.len() as u32; let mut struct_result = TypeLoopResult::TypeExists; while breadcrumb.next_member < num_fields { let mono_field = &monomorph.fields[breadcrumb.next_member as usize]; - struct_result = self.check_member_for_type_loops(&mono_field.concrete_type); + struct_result = Self::check_member_for_type_loops( + &self.type_loop_breadcrumbs, &self.definition_lookup, &self.mono_type_lookup, + &mut self.mono_search_key, &mono_field.concrete_type + ); if struct_result != TypeLoopResult::TypeExists { // Type loop or breadcrumb pushed, so break out of @@ -1246,14 +1243,17 @@ impl TypeTable { struct_result }, - MonomorphVariant::Procedure(_) => unreachable!(), - MonomorphVariant::Tuple(monomorph) => { + MonoTypeVariant::Procedure(_) => unreachable!(), + MonoTypeVariant::Tuple(monomorph) => { let 
num_members = monomorph.members.len() as u32; let mut tuple_result = TypeLoopResult::TypeExists; while breadcrumb.next_member < num_members { let tuple_member = &monomorph.members[breadcrumb.next_member as usize]; - tuple_result = self.check_member_for_type_loops(&tuple_member.concrete_type); + tuple_result = Self::check_member_for_type_loops( + &self.type_loop_breadcrumbs, &self.definition_lookup, &self.mono_type_lookup, + &mut self.mono_search_key, &tuple_member.concrete_type + ); if tuple_result != TypeLoopResult::TypeExists { break; @@ -1275,7 +1275,7 @@ impl TypeTable { TypeLoopResult::PushBreadcrumb(definition_id, concrete_type) => { // We recurse into the member type. self.type_loop_breadcrumbs[breadcrumb_idx] = breadcrumb; - self.handle_new_breadcrumb_for_type_loops(definition_id, concrete_type); + self.handle_new_breadcrumb_for_type_loops(arch, definition_id, concrete_type); }, TypeLoopResult::TypeLoop(first_idx) => { // Because we will be modifying breadcrumbs within the @@ -1290,24 +1290,22 @@ impl TypeTable { let breadcrumb = &mut self.type_loop_breadcrumbs[breadcrumb_idx]; let mut is_union = false; - let monomorph = self.mono_lookup.get_mut(breadcrumb.monomorph_idx); - // TODO: Match on monomorph directly here - match &mut monomorph.variant { - MonomorphVariant::Union(monomorph) => { - // Mark the currently processed variant as requiring heap - // allocation, then advance the *embedded* type. The loop above - // will then take care of advancing it to the next *member*. - let variant = &mut monomorph.variants[breadcrumb.next_member as usize]; - variant.lives_on_heap = true; - breadcrumb.next_embedded += 1; - is_union = true; - contains_union = true; - }, - _ => {}, // else: we don't care for now - } + // Check if type loop member is a union that may be + // broken up by moving some of its members to the heap. + let mono_type = &mut self.mono_types[breadcrumb.type_id.0 as usize]; + if let MonoTypeVariant::Union(union_type) = &mut mono_type.variant { + // Mark the variant that caused the loop as heap + // allocated to break the type loop. + let variant = &mut union_type.variants[breadcrumb.next_member as usize]; + variant.lives_on_heap = true; + breadcrumb.next_embedded += 1; + + is_union = true; + contains_union = true; + } // else: we don't care about the type for now loop_members.push(TypeLoopEntry{ - monomorph_idx: breadcrumb.monomorph_idx, + type_id: breadcrumb.type_id, is_union }); } @@ -1318,7 +1316,7 @@ impl TypeTable { // type loop error. This is because otherwise our // breadcrumb resolver ends up in an infinite loop. return Err(construct_type_loop_error( - self, &new_type_loop, modules, heap + &self.mono_types, &new_type_loop, modules, heap )); } @@ -1336,17 +1334,17 @@ impl TypeTable { // loop and that union ended up having variants that are not part of // a type loop. fn type_loop_source_span_and_message<'a>( - modules: &'a [Module], heap: &Heap, mono_lookup: &MonomorphTable, - definition_id: DefinitionId, monomorph_idx: i32, index_in_loop: usize + modules: &'a [Module], heap: &Heap, mono_types: &MonoTypeArray, + definition_id: DefinitionId, mono_type_id: TypeId, index_in_loop: usize ) -> (&'a InputSource, InputSpan, String) { // Note: because we will discover the type loop the *first* time we // instantiate a monomorph with the provided polymorphic arguments // (not all arguments are actually used in the type). We don't have // to care about a second instantiation where certain unused // polymorphic arguments are different. 
- let monomorph_type = &mono_lookup.get(monomorph_idx).concrete_type; + let mono_type = &mono_types[mono_type_id.0 as usize]; + let type_name = mono_type.concrete_type.display_name(heap); - let type_name = monomorph_type.display_name(&heap); let message = if index_in_loop == 0 { format!( "encountered an infinitely large type for '{}' (which can be fixed by \ @@ -1370,16 +1368,7 @@ impl TypeTable { ); } - fn retrieve_definition_id_if_possible(parts: &[ConcreteTypePart]) -> DefinitionId { - match &parts[0] { - ConcreteTypePart::Instance(v, _) | - ConcreteTypePart::Function(v, _) | - ConcreteTypePart::Component(v, _) => *v, - _ => DefinitionId::new_invalid(), - } - } - - fn construct_type_loop_error(table: &TypeTable, type_loop: &TypeLoop, modules: &[Module], heap: &Heap) -> ParseError { + fn construct_type_loop_error(mono_types: &MonoTypeArray, type_loop: &TypeLoop, modules: &[Module], heap: &Heap) -> ParseError { // Seek first entry to produce parse error. Then continue builder // pattern. This is the error case so efficiency can go home. let mut parse_error = None; @@ -1388,13 +1377,17 @@ impl TypeTable { let first_entry = &type_loop.members[next_member_index]; next_member_index += 1; - let first_definition_id = retrieve_definition_id_if_possible(&table.mono_lookup.get(first_entry.monomorph_idx).concrete_type.parts); - if first_definition_id.is_invalid() { + // Retrieve definition of first type in loop + let first_mono_type = &mono_types[first_entry.type_id.0 as usize]; + let first_definition_id = get_concrete_type_definition(&first_mono_type.concrete_type.parts); + if first_definition_id.is_none() { continue; } + let first_definition_id = first_definition_id.unwrap(); + // Produce error message for first type in loop let (first_module, first_span, first_message) = type_loop_source_span_and_message( - modules, heap, &table.mono_lookup, first_definition_id, first_entry.monomorph_idx, 0 + modules, heap, mono_types, first_definition_id, first_entry.type_id, 0 ); parse_error = Some(ParseError::new_error_at_span(first_module, first_span, first_message)); break; @@ -1405,13 +1398,15 @@ impl TypeTable { let mut error_counter = 1; for member_idx in next_member_index..type_loop.members.len() { let entry = &type_loop.members[member_idx]; - let definition_id = retrieve_definition_id_if_possible(&table.mono_lookup.get(entry.monomorph_idx).concrete_type.parts); - if definition_id.is_invalid() { - continue; // dont display tuples + let mono_type = &mono_types[entry.type_id.0 as usize]; + let definition_id = get_concrete_type_definition(&mono_type.concrete_type.parts); + if definition_id.is_none() { + continue; } + let definition_id = definition_id.unwrap(); let (module, span, message) = type_loop_source_span_and_message( - modules, heap, &table.mono_lookup, definition_id, entry.monomorph_idx, error_counter + modules, heap, mono_types, definition_id, entry.type_id, error_counter ); parse_error = parse_error.with_info_at_span(module, span, message); error_counter += 1; @@ -1426,9 +1421,9 @@ impl TypeTable { for entry in &type_loop.members { if entry.is_union { - let monomorph = self.mono_lookup.get(entry.monomorph_idx).variant.as_union(); - debug_assert!(!monomorph.variants.is_empty()); // otherwise it couldn't be part of the type loop - let has_stack_variant = monomorph.variants.iter().any(|variant| !variant.lives_on_heap); + let mono_type = self.mono_types[entry.type_id.0 as usize].variant.as_union(); + debug_assert!(!mono_type.variants.is_empty()); // otherwise it couldn't be part of the type loop + 
let has_stack_variant = mono_type.variants.iter().any(|variant| !variant.lives_on_heap); if has_stack_variant { can_be_broken = true; break; @@ -1438,7 +1433,7 @@ impl TypeTable { if !can_be_broken { // Construct a type loop error - return Err(construct_type_loop_error(self, type_loop, modules, heap)); + return Err(construct_type_loop_error(&self.mono_types, type_loop, modules, heap)); } } @@ -1456,35 +1451,26 @@ impl TypeTable { /// don't do any modifications of internal types here. Hence: if we /// return `PushBreadcrumb` then call `handle_new_breadcrumb_for_type_loops` /// to take care of storing the appropriate types. - fn check_member_for_type_loops(&self, definition_type: &ConcreteType) -> TypeLoopResult { + fn check_member_for_type_loops( + breadcrumbs: &[TypeLoopBreadcrumb], definition_map: &DefinitionMap, mono_type_map: &MonoTypeMap, + mono_key: &mut MonoSearchKey, concrete_type: &ConcreteType + ) -> TypeLoopResult { use ConcreteTypePart as CTP; // Depending on the type, lookup if the type has already been visited // (i.e. either already has its memory layed out, or is part of a type // loop because we've already visited the type) - debug_assert!(!definition_type.parts.is_empty()); - let (definition_id, monomorph_index) = match &definition_type.parts[0] { - CTP::Tuple(_) => { - let monomorph_index = self.mono_lookup.get_monomorph_index(&definition_type.parts, &[]); - - (DefinitionId::new_invalid(), monomorph_index) - }, - CTP::Instance(definition_id, _) | - CTP::Function(definition_id, _) | - CTP::Component(definition_id, _) => { - let base_type = self.type_lookup.get(definition_id).unwrap(); - let monomorph_index = self.mono_lookup.get_monomorph_index(&definition_type.parts, &base_type.poly_vars); - - (*definition_id, monomorph_index) - }, - _ => { - return TypeLoopResult::TypeExists - }, + debug_assert!(!concrete_type.parts.is_empty()); + let definition_id = if let ConcreteTypePart::Instance(definition_id, _) = concrete_type.parts[0] { + definition_id + } else { + DefinitionId::new_invalid() }; - if let Some(monomorph_index) = monomorph_index { - for (breadcrumb_idx, breadcrumb) in self.type_loop_breadcrumbs.iter().enumerate() { - if breadcrumb.monomorph_idx == monomorph_index { + Self::set_search_key_to_type(mono_key, definition_map, &concrete_type.parts); + if let Some(type_id) = mono_type_map.get(mono_key).copied() { + for (breadcrumb_idx, breadcrumb) in breadcrumbs.iter().enumerate() { + if breadcrumb.type_id == type_id { return TypeLoopResult::TypeLoop(breadcrumb_idx); } } @@ -1494,55 +1480,106 @@ impl TypeTable { // Type is not yet known, so we need to insert it into the lookup and // push a new breadcrumb. - return TypeLoopResult::PushBreadcrumb(definition_id, definition_type.clone()); + return TypeLoopResult::PushBreadcrumb(definition_id, concrete_type.clone()); } /// Handles the `PushBreadcrumb` result for a `check_member_for_type_loops` - /// call. - fn handle_new_breadcrumb_for_type_loops(&mut self, definition_id: DefinitionId, definition_type: ConcreteType) { + /// call. Will preallocate entries in the monomorphed type storage (with + /// all memory properties zeroed). 
+ fn handle_new_breadcrumb_for_type_loops(&mut self, arch: &TargetArch, definition_id: DefinitionId, concrete_type: ConcreteType) { use DefinedTypeVariant as DTV; use ConcreteTypePart as CTP; let mut is_union = false; - let monomorph_index = match &definition_type.parts[0] { + let type_id = match &concrete_type.parts[0] { + // Builtin types + CTP::Void | CTP::Message | CTP::Bool | + CTP::UInt8 | CTP::UInt16 | CTP::UInt32 | CTP::UInt64 | + CTP::SInt8 | CTP::SInt16 | CTP::SInt32 | CTP::SInt64 | + CTP::Character | CTP::String | + CTP::Array | CTP::Slice | CTP::Input | CTP::Output | CTP::Pointer => { + // Insert the entry for the builtin type, we should be able to + // immediately "steal" the size from the preinserted builtins. + let base_type_id = match &concrete_type.parts[0] { + CTP::Void => arch.void_type_id, + CTP::Message => arch.message_type_id, + CTP::Bool => arch.bool_type_id, + CTP::UInt8 => arch.uint8_type_id, + CTP::UInt16 => arch.uint16_type_id, + CTP::UInt32 => arch.uint32_type_id, + CTP::UInt64 => arch.uint64_type_id, + CTP::SInt8 => arch.sint8_type_id, + CTP::SInt16 => arch.sint16_type_id, + CTP::SInt32 => arch.sint32_type_id, + CTP::SInt64 => arch.sint64_type_id, + CTP::Character => arch.char_type_id, + CTP::String => arch.string_type_id, + CTP::Array => arch.array_type_id, + CTP::Slice => arch.slice_type_id, + CTP::Input => arch.input_type_id, + CTP::Output => arch.output_type_id, + CTP::Pointer => arch.pointer_type_id, + _ => unreachable!(), + }; + let base_type = &self.mono_types[base_type_id.0 as usize]; + let base_type_size = base_type.size; + let base_type_alignment = base_type.alignment; + + let type_id = TypeId(self.mono_types.len() as i64); + Self::set_search_key_to_type(&mut self.mono_search_key, &self.definition_lookup, &concrete_type.parts); + self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id); + self.mono_types.push(MonoType{ + type_id, + concrete_type, + size: base_type_size, + alignment: base_type_alignment, + variant: MonoTypeVariant::Builtin + }); + + type_id + }, + // User-defined types CTP::Tuple(num_embedded) => { debug_assert!(definition_id.is_invalid()); // because tuples do not have an associated `DefinitionId` let mut members = Vec::with_capacity(*num_embedded as usize); - for section in ConcreteTypeIter::new(&definition_type.parts, 0) { + for section in ConcreteTypeIter::new(&concrete_type.parts, 0) { members.push(TupleMonomorphMember{ + type_id: TypeId::new_invalid(), concrete_type: ConcreteType{ parts: Vec::from(section) }, size: 0, alignment: 0, offset: 0 }); } - let mono_index = self.mono_lookup.insert_with_zero_size_and_alignment( - definition_type, &[], - MonomorphVariant::Tuple(TupleMonomorph{ - members, - }) - ); - mono_index + let type_id = TypeId(self.mono_types.len() as i64); + Self::set_search_key_to_tuple(&mut self.mono_search_key, &self.definition_lookup, &concrete_type.parts); + self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id); + self.mono_types.push(MonoType::new_empty(type_id, concrete_type, MonoTypeVariant::Tuple(TupleMonomorph{ members }))); + + type_id }, CTP::Instance(_check_definition_id, _) => { debug_assert_eq!(definition_id, *_check_definition_id); // because this is how `definition_id` was determined - let base_type = self.type_lookup.get_mut(&definition_id).unwrap(); - let monomorph_index = match &mut base_type.definition { + + Self::set_search_key_to_type(&mut self.mono_search_key, &self.definition_lookup, &concrete_type.parts); + let base_type = 
self.definition_lookup.get(&definition_id).unwrap(); + let type_id = match &base_type.definition { DTV::Enum(definition) => { // The enum is a bit exceptional in that when we insert // it we we will immediately set its size/alignment: // there is nothing to compute here. debug_assert!(definition.size != 0 && definition.alignment != 0); - let mono_index = self.mono_lookup.insert_with_zero_size_and_alignment( - definition_type, &base_type.poly_vars, MonomorphVariant::Enum - ); - let mono_type = self.mono_lookup.get_mut(mono_index); + let type_id = TypeId(self.mono_types.len() as i64); + self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id); + self.mono_types.push(MonoType::new_empty(type_id, concrete_type, MonoTypeVariant::Enum)); + + let mono_type = &mut self.mono_types[type_id.0 as usize]; mono_type.size = definition.size; mono_type.alignment = definition.alignment; - mono_index + type_id }, DTV::Union(definition) => { // Create all the variants with their concrete types @@ -1550,8 +1587,9 @@ impl TypeTable { for poly_variant in &definition.variants { let mut mono_embedded = Vec::with_capacity(poly_variant.embedded.len()); for poly_embedded in &poly_variant.embedded { - let mono_concrete = Self::construct_concrete_type(poly_embedded, &definition_type); + let mono_concrete = Self::construct_concrete_type(poly_embedded, &concrete_type); mono_embedded.push(UnionMonomorphEmbedded{ + type_id: TypeId::new_invalid(), concrete_type: mono_concrete, size: 0, alignment: 0, @@ -1565,24 +1603,27 @@ impl TypeTable { }) } - let mono_index = self.mono_lookup.insert_with_zero_size_and_alignment( - definition_type, &base_type.poly_vars, - MonomorphVariant::Union(UnionMonomorph{ - variants: mono_variants, - tag_size: definition.tag_size, - heap_size: 0, - heap_alignment: 0 - }) - ); + let type_id = TypeId(self.mono_types.len() as i64); + let tag_size = definition.tag_size; + Self::set_search_key_to_type(&mut self.mono_search_key, &self.definition_lookup, &concrete_type.parts); + self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id); + self.mono_types.push(MonoType::new_empty(type_id, concrete_type, MonoTypeVariant::Union(UnionMonomorph{ + variants: mono_variants, + tag_size, + heap_size: 0, + heap_alignment: 0, + }))); is_union = true; - mono_index + type_id }, DTV::Struct(definition) => { + // Create fields let mut mono_fields = Vec::with_capacity(definition.fields.len()); for poly_field in &definition.fields { - let mono_concrete = Self::construct_concrete_type(&poly_field.parser_type, &definition_type); + let mono_concrete = Self::construct_concrete_type(&poly_field.parser_type, &concrete_type); mono_fields.push(StructMonomorphField{ + type_id: TypeId::new_invalid(), concrete_type: mono_concrete, size: 0, alignment: 0, @@ -1590,30 +1631,28 @@ impl TypeTable { }) } - let mono_index = self.mono_lookup.insert_with_zero_size_and_alignment( - definition_type, &base_type.poly_vars, - MonomorphVariant::Struct(StructMonomorph{ fields: mono_fields }) - ); + let type_id = TypeId(self.mono_types.len() as i64); + Self::set_search_key_to_type(&mut self.mono_search_key, &self.definition_lookup, &concrete_type.parts); + self.mono_type_lookup.insert(self.mono_search_key.clone(), type_id); + self.mono_types.push(MonoType::new_empty(type_id, concrete_type, MonoTypeVariant::Struct(StructMonomorph{ + fields: mono_fields, + }))); - mono_index + type_id }, - DTV::Function(_) | DTV::Component(_) => { + DTV::Procedure(_) => { unreachable!("pushing type resolving breadcrumb for procedure type") }, }; - 
monomorph_index + type_id }, - _ => unreachable!(), + CTP::Function(_, _) | CTP::Component(_, _) => todo!("function pointers"), }; - self.encountered_types.push(TypeLoopEntry{ - monomorph_idx: monomorph_index, - is_union, - }); - + self.encountered_types.push(TypeLoopEntry{ type_id, is_union }); self.type_loop_breadcrumbs.push(TypeLoopBreadcrumb{ - monomorph_idx: monomorph_index, + type_id, next_member: 0, next_embedded: 0, }); @@ -1664,7 +1703,7 @@ impl TypeTable { // Not builtin, but if all code is working correctly, we only care // about the polymorphic argument at this point. if let PTV::PolymorphicArgument(_container_definition_id, poly_arg_idx) = member_part.variant { - debug_assert_eq!(_container_definition_id, get_concrete_type_definition(container_type)); + debug_assert_eq!(_container_definition_id, get_concrete_type_definition(&container_type.parts).unwrap()); let mut container_iter = container_type.embedded_iter(0); for _ in 0..poly_arg_idx { @@ -1701,17 +1740,21 @@ impl TypeTable { // optimization, we're working around borrowing rules here. // Just finished type loop detection, so we're left with the encountered - // types only + // types only. If we don't have any (a builtin type's monomorph was + // added to the type table) then this function shouldn't be called at + // all. debug_assert!(self.type_loops.is_empty()); debug_assert!(!self.encountered_types.is_empty()); debug_assert!(self.memory_layout_breadcrumbs.is_empty()); debug_assert!(self.size_alignment_stack.is_empty()); + let (ptr_size, ptr_align) = self.mono_types[arch.pointer_type_id.0 as usize].get_size_alignment().unwrap(); + // Push the first entry (the type we originally started with when we // were detecting type loops) let first_entry = &self.encountered_types[0]; self.memory_layout_breadcrumbs.push(MemoryBreadcrumb{ - monomorph_idx: first_entry.monomorph_idx, + type_id: first_entry.type_id, next_member: 0, next_embedded: 0, first_size_alignment_idx: 0, @@ -1722,16 +1765,16 @@ impl TypeTable { let cur_breadcrumb_idx = self.memory_layout_breadcrumbs.len() - 1; let mut breadcrumb = self.memory_layout_breadcrumbs[cur_breadcrumb_idx].clone(); - let mono_type = self.mono_lookup.get(breadcrumb.monomorph_idx); + let mono_type = &self.mono_types[breadcrumb.type_id.0 as usize]; match &mono_type.variant { - MonomorphVariant::Enum => { + MonoTypeVariant::Builtin | MonoTypeVariant::Enum => { // Size should already be computed - if cfg!(debug_assertions) { - let mono_type = self.mono_lookup.get(breadcrumb.monomorph_idx); + dbg_code!({ + let mono_type = &self.mono_types[breadcrumb.type_id.0 as usize]; debug_assert!(mono_type.size != 0 && mono_type.alignment != 0); - } + }); }, - MonomorphVariant::Union(mono_type) => { + MonoTypeVariant::Union(mono_type) => { // Retrieve size/alignment of each embedded type. We do not // compute the offsets or total type sizes yet. 
let num_variants = mono_type.variants.len() as u32; @@ -1746,7 +1789,12 @@ impl TypeTable { let num_embedded = mono_variant.embedded.len() as u32; while breadcrumb.next_embedded < num_embedded { let mono_embedded = &mono_variant.embedded[breadcrumb.next_embedded as usize]; - match self.get_memory_layout_or_breadcrumb(arch, &mono_embedded.concrete_type.parts) { + let layout_result = Self::get_memory_layout_or_breadcrumb( + &self.definition_lookup, &self.mono_type_lookup, &self.mono_types, + &mut self.mono_search_key, arch, &mono_embedded.concrete_type.parts, + self.size_alignment_stack.len() + ); + match layout_result { MemoryLayoutResult::TypeExists(size, alignment) => { self.size_alignment_stack.push((size, alignment)); }, @@ -1772,19 +1820,18 @@ impl TypeTable { let mut max_size = mono_type.tag_size; let mut max_alignment = mono_type.tag_size; - let mono_info = self.mono_lookup.get_mut(breadcrumb.monomorph_idx); - let mono_type = mono_info.variant.as_union_mut(); + let mono_type = &mut self.mono_types[breadcrumb.type_id.0 as usize]; + let union_type = mono_type.variant.as_union_mut(); let mut size_alignment_idx = breadcrumb.first_size_alignment_idx as usize; - for variant in &mut mono_type.variants { + for variant in &mut union_type.variants { // We're doing stack computations, so always start with // the tag size/alignment. - let mut variant_offset = mono_type.tag_size; - let mut variant_alignment = mono_type.tag_size; + let mut variant_offset = union_type.tag_size; + let mut variant_alignment = union_type.tag_size; if variant.lives_on_heap { // Variant lives on heap, so just a pointer - let (ptr_size, ptr_align) = arch.pointer_size_alignment; align_offset_to(&mut variant_offset, ptr_align); variant_offset += ptr_size; @@ -1810,18 +1857,23 @@ impl TypeTable { max_alignment = max_alignment.max(variant_alignment); } - mono_info.size = max_size; - mono_info.alignment = max_alignment; + mono_type.size = max_size; + mono_type.alignment = max_alignment; self.size_alignment_stack.truncate(breadcrumb.first_size_alignment_idx as usize); }, - MonomorphVariant::Struct(mono_type) => { + MonoTypeVariant::Struct(mono_type) => { // Retrieve size and alignment of each struct member. 
We'll // compute the offsets once all of those are known let num_fields = mono_type.fields.len() as u32; while breadcrumb.next_member < num_fields { let mono_field = &mono_type.fields[breadcrumb.next_member as usize]; - match self.get_memory_layout_or_breadcrumb(arch, &mono_field.concrete_type.parts) { + let layout_result = Self::get_memory_layout_or_breadcrumb( + &self.definition_lookup, &self.mono_type_lookup, &self.mono_types, + &mut self.mono_search_key, arch, &mono_field.concrete_type.parts, + self.size_alignment_stack.len() + ); + match layout_result { MemoryLayoutResult::TypeExists(size, alignment) => { self.size_alignment_stack.push((size, alignment)) }, @@ -1839,11 +1891,11 @@ impl TypeTable { let mut cur_offset = 0; let mut max_alignment = 1; - let mono_info = self.mono_lookup.get_mut(breadcrumb.monomorph_idx); - let mono_type = mono_info.variant.as_struct_mut(); + let mono_type = &mut self.mono_types[breadcrumb.type_id.0 as usize]; + let struct_type = mono_type.variant.as_struct_mut(); let mut size_alignment_idx = breadcrumb.first_size_alignment_idx as usize; - for field in &mut mono_type.fields { + for field in &mut struct_type.fields { let (size, alignment) = self.size_alignment_stack[size_alignment_idx]; field.size = size; field.alignment = alignment; @@ -1856,18 +1908,23 @@ impl TypeTable { max_alignment = max_alignment.max(alignment); } - mono_info.size = cur_offset; - mono_info.alignment = max_alignment; + mono_type.size = cur_offset; + mono_type.alignment = max_alignment; self.size_alignment_stack.truncate(breadcrumb.first_size_alignment_idx as usize); }, - MonomorphVariant::Procedure(_) => { + MonoTypeVariant::Procedure(_) => { unreachable!(); }, - MonomorphVariant::Tuple(mono_type) => { + MonoTypeVariant::Tuple(mono_type) => { let num_members = mono_type.members.len() as u32; while breadcrumb.next_member < num_members { let mono_member = &mono_type.members[breadcrumb.next_member as usize]; - match self.get_memory_layout_or_breadcrumb(arch, &mono_member.concrete_type.parts) { + let layout_result = Self::get_memory_layout_or_breadcrumb( + &self.definition_lookup, &self.mono_type_lookup, &self.mono_types, + &mut self.mono_search_key, arch, &mono_member.concrete_type.parts, + self.size_alignment_stack.len() + ); + match layout_result { MemoryLayoutResult::TypeExists(size, alignment) => { self.size_alignment_stack.push((size, alignment)); }, @@ -1885,15 +1942,15 @@ impl TypeTable { let mut cur_offset = 0; let mut max_alignment = 1; - let mono_info = self.mono_lookup.get_mut(breadcrumb.monomorph_idx); - let mono_type = mono_info.variant.as_tuple_mut(); + let mono_type = &mut self.mono_types[breadcrumb.type_id.0 as usize]; + let mono_tuple = mono_type.variant.as_tuple_mut(); let mut size_alignment_index = breadcrumb.first_size_alignment_idx as usize; for member_index in 0..num_members { let (member_size, member_alignment) = self.size_alignment_stack[size_alignment_index]; align_offset_to(&mut cur_offset, member_alignment); size_alignment_index += 1; - let member = &mut mono_type.members[member_index as usize]; + let member = &mut mono_tuple.members[member_index as usize]; member.size = member_size; member.alignment = member_alignment; member.offset = cur_offset; @@ -1902,8 +1959,8 @@ impl TypeTable { max_alignment = max_alignment.max(member_alignment); } - mono_info.size = cur_offset; - mono_info.alignment = max_alignment; + mono_type.size = cur_offset; + mono_type.alignment = max_alignment; self.size_alignment_stack.truncate(breadcrumb.first_size_alignment_idx as usize); }, } 
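
(Editor's aside, not part of the patch.) The struct and tuple hunks above all use the same layout scheme: walk the members, align a running offset to each member's alignment, record that offset, advance by the member's size, and keep the maximum member alignment as the aggregate's alignment. Below is a minimal standalone sketch of that scheme; `lay_out_members` is an illustrative name only, and since the body of `align_offset_to` is not visible in this hunk its implementation here is an assumption. The real code stores the results on `StructMonomorphField`/`TupleMonomorphMember` entries via the `size_alignment_stack` rather than returning them.

// Hypothetical stand-in for the patch's `align_offset_to`; the actual body is
// outside the shown hunks, so this modulo-based rounding is an assumption.
fn align_offset_to(offset: &mut usize, alignment: usize) {
    debug_assert!(alignment > 0);
    let remainder = *offset % alignment;
    if remainder != 0 {
        *offset += alignment - remainder;
    }
}

// Illustrative-only routine: given each member's (size, alignment), compute
// per-member offsets plus the aggregate size and alignment, mirroring what the
// breadcrumb loop above does for struct fields and tuple members.
fn lay_out_members(members: &[(usize, usize)]) -> (Vec<usize>, usize, usize) {
    let mut offsets = Vec::with_capacity(members.len());
    let mut cur_offset = 0;
    let mut max_alignment = 1;

    for &(size, alignment) in members {
        // Place the member at the next offset satisfying its alignment.
        align_offset_to(&mut cur_offset, alignment);
        offsets.push(cur_offset);
        cur_offset += size;
        max_alignment = max_alignment.max(alignment);
    }

    // Note: as in the patch, the aggregate size is the final running offset;
    // it is not rounded up to the aggregate alignment at this point.
    (offsets, cur_offset, max_alignment)
}
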
@@ -1925,7 +1982,7 @@ impl TypeTable { // First pass, use buffer to store size/alignment to prevent // borrowing issues. - let mono_type = self.mono_lookup.get(entry.monomorph_idx).variant.as_union(); + let mono_type = self.mono_types[entry.type_id.0 as usize].variant.as_union(); for variant in &mono_type.variants { if !variant.lives_on_heap { continue; @@ -1934,17 +1991,22 @@ impl TypeTable { debug_assert!(!variant.embedded.is_empty()); for embedded in &variant.embedded { - match self.get_memory_layout_or_breadcrumb(arch, &embedded.concrete_type.parts) { + let layout_result = Self::get_memory_layout_or_breadcrumb( + &self.definition_lookup, &self.mono_type_lookup, &self.mono_types, + &mut self.mono_search_key, arch, &embedded.concrete_type.parts, + self.size_alignment_stack.len() + ); + match layout_result { MemoryLayoutResult::TypeExists(size, alignment) => { self.size_alignment_stack.push((size, alignment)); }, - _ => unreachable!(), + _ => unreachable!(), // type was not truly infinite, so type must have been found } } } // Second pass, apply the size/alignment values in our buffer - let mono_type = self.mono_lookup.get_mut(entry.monomorph_idx).variant.as_union_mut(); + let mono_type = self.mono_types[entry.type_id.0 as usize].variant.as_union_mut(); let mut max_size = 0; let mut max_alignment = 1; @@ -1990,64 +2052,66 @@ impl TypeTable { /// is called *after* type loops have been succesfully resolved. Hence we /// may assume that all monomorph entries exist, but we may not assume that /// those entries already have their size/alignment computed. - fn get_memory_layout_or_breadcrumb(&self, arch: &TargetArch, parts: &[ConcreteTypePart]) -> MemoryLayoutResult { + // Passed parameters are messy. But need to strike balance between borrowing + // and allocations in hot loops. So it is what it is. 
+ fn get_memory_layout_or_breadcrumb( + definition_map: &DefinitionMap, mono_type_map: &MonoTypeMap, mono_types: &MonoTypeArray, + search_key: &mut MonoSearchKey, arch: &TargetArch, parts: &[ConcreteTypePart], + size_alignment_stack_len: usize, + ) -> MemoryLayoutResult { use ConcreteTypePart as CTP; debug_assert!(!parts.is_empty()); - let (builtin_size, builtin_alignment) = match parts[0] { - CTP::Void => (0, 1), - CTP::Message => arch.array_size_alignment, - CTP::Bool => (1, 1), - CTP::UInt8 => (1, 1), - CTP::UInt16 => (2, 2), - CTP::UInt32 => (4, 4), - CTP::UInt64 => (8, 8), - CTP::SInt8 => (1, 1), - CTP::SInt16 => (2, 2), - CTP::SInt32 => (4, 4), - CTP::SInt64 => (8, 8), - CTP::Character => (4, 4), - CTP::String => arch.string_size_alignment, - CTP::Array => arch.array_size_alignment, - CTP::Slice => arch.array_size_alignment, - CTP::Input => arch.port_size_alignment, - CTP::Output => arch.port_size_alignment, + let type_id = match parts[0] { + CTP::Void => arch.void_type_id, + CTP::Message => arch.message_type_id, + CTP::Bool => arch.bool_type_id, + CTP::UInt8 => arch.uint8_type_id, + CTP::UInt16 => arch.uint16_type_id, + CTP::UInt32 => arch.uint32_type_id, + CTP::UInt64 => arch.uint64_type_id, + CTP::SInt8 => arch.sint8_type_id, + CTP::SInt16 => arch.sint16_type_id, + CTP::SInt32 => arch.sint32_type_id, + CTP::SInt64 => arch.sint64_type_id, + CTP::Character => arch.char_type_id, + CTP::String => arch.string_type_id, + CTP::Array => arch.array_type_id, + CTP::Slice => arch.slice_type_id, + CTP::Input => arch.input_type_id, + CTP::Output => arch.output_type_id, + CTP::Pointer => arch.pointer_type_id, CTP::Tuple(_) => { - let mono_index = self.mono_lookup.get_monomorph_index(parts, &[]).unwrap(); - if let Some((size, alignment)) = self.mono_lookup.get_monomorph_size_alignment(mono_index) { - return MemoryLayoutResult::TypeExists(size, alignment); - } else { - return MemoryLayoutResult::PushBreadcrumb(MemoryBreadcrumb{ - monomorph_idx: mono_index, - next_member: 0, - next_embedded: 0, - first_size_alignment_idx: self.size_alignment_stack.len() as u32, - }) - } + Self::set_search_key_to_tuple(search_key, definition_map, parts); + let type_id = mono_type_map.get(&search_key).copied().unwrap(); + + type_id }, CTP::Instance(definition_id, _) => { // Retrieve entry and the specific monomorph index by applying // the full concrete type. 
- let entry = self.type_lookup.get(&definition_id).unwrap(); - let mono_index = self.mono_lookup.get_monomorph_index(parts, &entry.poly_vars).unwrap(); + let definition_type = definition_map.get(&definition_id).unwrap(); + search_key.set(parts, &definition_type.poly_vars); + let type_id = mono_type_map.get(&search_key).copied().unwrap(); - if let Some((size, alignment)) = self.mono_lookup.get_monomorph_size_alignment(mono_index) { - return MemoryLayoutResult::TypeExists(size, alignment); - } else { - return MemoryLayoutResult::PushBreadcrumb(MemoryBreadcrumb{ - monomorph_idx: mono_index, - next_member: 0, - next_embedded: 0, - first_size_alignment_idx: self.size_alignment_stack.len() as u32, - }); - } + type_id }, CTP::Function(_, _) | CTP::Component(_, _) => { todo!("storage for 'function pointers'"); } }; - return MemoryLayoutResult::TypeExists(builtin_size, builtin_alignment); + let mono_type = &mono_types[type_id.0 as usize]; + if let Some((size, alignment)) = mono_type.get_size_alignment() { + return MemoryLayoutResult::TypeExists(size, alignment); + } else { + return MemoryLayoutResult::PushBreadcrumb(MemoryBreadcrumb{ + type_id, + next_member: 0, + next_embedded: 0, + first_size_alignment_idx: size_alignment_stack_len as u32, + }); + } } /// Returns tag concrete type (always a builtin integer type), the size of @@ -2102,6 +2166,56 @@ impl TypeTable { } } } + + /// Sets the search key to a specific type. + fn set_search_key_to_type(search_key: &mut MonoSearchKey, definition_map: &DefinitionMap, type_parts: &[ConcreteTypePart]) { + use ConcreteTypePart as CTP; + + match type_parts[0] { + // Builtin types without any embedded types + CTP::Void | CTP::Message | CTP::Bool | + CTP::UInt8 | CTP::UInt16 | CTP::UInt32 | CTP::UInt64 | + CTP::SInt8 | CTP::SInt16 | CTP::SInt32 | CTP::SInt64 | + CTP::Character | CTP::String => { + debug_assert_eq!(type_parts.len(), 1); + search_key.set_top_type(type_parts[0]); + }, + // Builtin types with a single nested type + CTP::Array | CTP::Slice | CTP::Input | CTP::Output | CTP::Pointer => { + debug_assert_eq!(type_parts[0].num_embedded(), 1); + search_key.set(type_parts, &POLY_VARS_IN_USE[..1]) + }, + // User-defined types + CTP::Tuple(_) => { + Self::set_search_key_to_tuple(search_key, definition_map, type_parts); + }, + CTP::Instance(definition_id, _) => { + let definition_type = definition_map.get(&definition_id).unwrap(); + search_key.set(type_parts, &definition_type.poly_vars); + }, + CTP::Function(_, _) | CTP::Component(_, _) => { + todo!("implement function pointers") + }, + } + } + + fn set_search_key_to_tuple(search_key: &mut MonoSearchKey, definition_map: &DefinitionMap, type_parts: &[ConcreteTypePart]) { + dbg_code!({ + let is_tuple = if let ConcreteTypePart::Tuple(_) = type_parts[0] { true } else { false }; + assert!(is_tuple); + }); + search_key.set_top_type(type_parts[0]); + for subtree in ConcreteTypeIter::new(type_parts, 0) { + if let Some(definition_id) = get_concrete_type_definition(subtree) { + // A definition, so retrieve poly var usage info + let definition_type = definition_map.get(&definition_id).unwrap(); + search_key.push_subtree(subtree, &definition_type.poly_vars); + } else { + // Not a definition, so all type information is important + search_key.push_subtype(subtree, true); + } + } + } } #[inline] @@ -2113,11 +2227,17 @@ fn align_offset_to(offset: &mut usize, alignment: usize) { } #[inline] -fn get_concrete_type_definition(concrete: &ConcreteType) -> DefinitionId { - if let ConcreteTypePart::Instance(definition_id, _) = 
concrete.parts[0] { - return definition_id; - } else { - debug_assert!(false, "passed {:?} to the type table", concrete); - return DefinitionId::new_invalid() +fn get_concrete_type_definition(concrete_parts: &[ConcreteTypePart]) -> Option { + match concrete_parts[0] { + ConcreteTypePart::Instance(definition_id, _) => { + return Some(definition_id) + }, + ConcreteTypePart::Function(definition_id, _) | + ConcreteTypePart::Component(definition_id, _) => { + return Some(definition_id.upcast()); + }, + _ => { + return None; + }, } } \ No newline at end of file diff --git a/src/protocol/parser/visitor.rs b/src/protocol/parser/visitor.rs index 3d62f1c603b4b29b8d640dee5b355d8186fc3383..8bf927f563aeb8a2819f703c87477c4db8d60fb4 100644 --- a/src/protocol/parser/visitor.rs +++ b/src/protocol/parser/visitor.rs @@ -6,8 +6,10 @@ use crate::protocol::symbol_table::{SymbolTable}; type Unit = (); pub(crate) type VisitorResult = Result; -/// Globally configured vector capacity for buffers in visitor implementations -pub(crate) const BUFFER_INIT_CAPACITY: usize = 256; +/// Globally configured capacity for large-ish buffers in visitor impls +pub(crate) const BUFFER_INIT_CAP_LARGE: usize = 256; +/// Globally configured capacity for small-ish buffers in visitor impls +pub(crate) const BUFFER_INIT_CAP_SMALL: usize = 64; /// General context structure that is used while traversing the AST. pub(crate) struct Ctx<'p> { @@ -30,218 +32,243 @@ impl<'p> Ctx<'p> { } } -/// Visitor is a generic trait that will fully walk the AST. The default -/// implementation of the visitors is to not recurse. The exception is the -/// top-level `visit_definition`, `visit_stmt` and `visit_expr` methods, which -/// call the appropriate visitor function. -pub(crate) trait Visitor { - // Entry point - fn visit_module(&mut self, ctx: &mut Ctx) -> VisitorResult { - let mut def_index = 0; - let module_root_id = ctx.modules[ctx.module_idx].root_id; - loop { - let definition_id = { - let root = &ctx.heap[module_root_id]; - if def_index >= root.definitions.len() { - return Ok(()) - } - - root.definitions[def_index] - }; - - self.visit_definition(ctx, definition_id)?; - def_index += 1; - } - } - - // Definitions - // --- enum matching - fn visit_definition(&mut self, ctx: &mut Ctx, id: DefinitionId) -> VisitorResult { - match &ctx.heap[id] { - Definition::Enum(def) => { - let def = def.this; - self.visit_enum_definition(ctx, def) - }, - Definition::Union(def) => { - let def = def.this; - self.visit_union_definition(ctx, def) - } - Definition::Struct(def) => { - let def = def.this; - self.visit_struct_definition(ctx, def) - }, - Definition::Component(def) => { - let def = def.this; - self.visit_component_definition(ctx, def) - }, - Definition::Function(def) => { - let def = def.this; - self.visit_function_definition(ctx, def) - } - } - } - - // --- enum variant handling - fn visit_enum_definition(&mut self, _ctx: &mut Ctx, _id: EnumDefinitionId) -> VisitorResult { Ok(()) } - fn visit_union_definition(&mut self, _ctx: &mut Ctx, _id: UnionDefinitionId) -> VisitorResult{ Ok(()) } - fn visit_struct_definition(&mut self, _ctx: &mut Ctx, _id: StructDefinitionId) -> VisitorResult { Ok(()) } - fn visit_component_definition(&mut self, _ctx: &mut Ctx, _id: ComponentDefinitionId) -> VisitorResult { Ok(()) } - fn visit_function_definition(&mut self, _ctx: &mut Ctx, _id: FunctionDefinitionId) -> VisitorResult { Ok(()) } - - // Statements - // --- enum matching - fn visit_stmt(&mut self, ctx: &mut Ctx, id: StatementId) -> VisitorResult { - match &ctx.heap[id] { 
+/// Implements the logic that checks the statement union retrieved from the +/// AST and calls the appropriate visit function. This entire macro assumes that +/// `$this` points to `self`, `$stmt` is the statement of type `Statement`, +/// `$ctx` is the context passed to all the visitor calls (of the form +/// `visit_x_stmt(context, id)`) and `$default_return` is the default return +/// value for the statements that will not be visited. +macro_rules! visitor_recursive_statement_impl { + ($this:expr, $stmt:expr, $ctx:expr, $default_return:expr) => { + match $stmt { Statement::Block(stmt) => { let this = stmt.this; - self.visit_block_stmt(ctx, this) + $this.visit_block_stmt($ctx, this) }, - Statement::EndBlock(_stmt) => Ok(()), + Statement::EndBlock(_stmt) => $default_return, Statement::Local(stmt) => { let this = stmt.this(); - self.visit_local_stmt(ctx, this) + $this.visit_local_stmt($ctx, this) }, Statement::Labeled(stmt) => { let this = stmt.this; - self.visit_labeled_stmt(ctx, this) + $this.visit_labeled_stmt($ctx, this) }, Statement::If(stmt) => { let this = stmt.this; - self.visit_if_stmt(ctx, this) + $this.visit_if_stmt($ctx, this) }, - Statement::EndIf(_stmt) => Ok(()), + Statement::EndIf(_stmt) => $default_return, Statement::While(stmt) => { let this = stmt.this; - self.visit_while_stmt(ctx, this) + $this.visit_while_stmt($ctx, this) }, - Statement::EndWhile(_stmt) => Ok(()), + Statement::EndWhile(_stmt) => $default_return, Statement::Break(stmt) => { let this = stmt.this; - self.visit_break_stmt(ctx, this) + $this.visit_break_stmt($ctx, this) }, Statement::Continue(stmt) => { let this = stmt.this; - self.visit_continue_stmt(ctx, this) + $this.visit_continue_stmt($ctx, this) }, Statement::Synchronous(stmt) => { let this = stmt.this; - self.visit_synchronous_stmt(ctx, this) + $this.visit_synchronous_stmt($ctx, this) }, - Statement::EndSynchronous(_stmt) => Ok(()), + Statement::EndSynchronous(_stmt) => $default_return, Statement::Fork(stmt) => { let this = stmt.this; - self.visit_fork_stmt(ctx, this) + $this.visit_fork_stmt($ctx, this) }, - Statement::EndFork(_stmt) => Ok(()), + Statement::EndFork(_stmt) => $default_return, Statement::Select(stmt) => { let this = stmt.this; - self.visit_select_stmt(ctx, this) + $this.visit_select_stmt($ctx, this) }, - Statement::EndSelect(_stmt) => Ok(()), + Statement::EndSelect(_stmt) => $default_return, Statement::Return(stmt) => { let this = stmt.this; - self.visit_return_stmt(ctx, this) + $this.visit_return_stmt($ctx, this) }, Statement::Goto(stmt) => { let this = stmt.this; - self.visit_goto_stmt(ctx, this) + $this.visit_goto_stmt($ctx, this) }, Statement::New(stmt) => { let this = stmt.this; - self.visit_new_stmt(ctx, this) + $this.visit_new_stmt($ctx, this) }, Statement::Expression(stmt) => { let this = stmt.this; - self.visit_expr_stmt(ctx, this) + $this.visit_expr_stmt($ctx, this) + } + } + }; +} + +macro_rules! visitor_recursive_local_impl { + ($this:expr, $local:expr, $ctx:expr) => { + match $local { + LocalStatement::Channel(local) => { + let this = local.this; + $this.visit_local_channel_stmt($ctx, this) + }, + LocalStatement::Memory(local) => { + let this = local.this; + $this.visit_local_memory_stmt($ctx, this) } } } +} - fn visit_local_stmt(&mut self, ctx: &mut Ctx, id: LocalStatementId) -> VisitorResult { - match &ctx.heap[id] { - LocalStatement::Channel(stmt) => { - let this = stmt.this; - self.visit_local_channel_stmt(ctx, this) +macro_rules! 
visitor_recursive_definition_impl { + ($this:expr, $definition:expr, $ctx:expr) => { + match $definition { + Definition::Enum(def) => { + let def = def.this; + $this.visit_enum_definition($ctx, def) }, - LocalStatement::Memory(stmt) => { - let this = stmt.this; - self.visit_local_memory_stmt(ctx, this) + Definition::Union(def) => { + let def = def.this; + $this.visit_union_definition($ctx, def) + }, + Definition::Struct(def) => { + let def = def.this; + $this.visit_struct_definition($ctx, def) + }, + Definition::Procedure(def) => { + let def = def.this; + $this.visit_procedure_definition($ctx, def) }, } } +} - // --- enum variant handling - fn visit_block_stmt(&mut self, _ctx: &mut Ctx, _id: BlockStatementId) -> VisitorResult { Ok(()) } - fn visit_local_memory_stmt(&mut self, _ctx: &mut Ctx, _id: MemoryStatementId) -> VisitorResult { Ok(()) } - fn visit_local_channel_stmt(&mut self, _ctx: &mut Ctx, _id: ChannelStatementId) -> VisitorResult { Ok(()) } - fn visit_labeled_stmt(&mut self, _ctx: &mut Ctx, _id: LabeledStatementId) -> VisitorResult { Ok(()) } - fn visit_if_stmt(&mut self, _ctx: &mut Ctx, _id: IfStatementId) -> VisitorResult { Ok(()) } - fn visit_while_stmt(&mut self, _ctx: &mut Ctx, _id: WhileStatementId) -> VisitorResult { Ok(()) } - fn visit_break_stmt(&mut self, _ctx: &mut Ctx, _id: BreakStatementId) -> VisitorResult { Ok(()) } - fn visit_continue_stmt(&mut self, _ctx: &mut Ctx, _id: ContinueStatementId) -> VisitorResult { Ok(()) } - fn visit_synchronous_stmt(&mut self, _ctx: &mut Ctx, _id: SynchronousStatementId) -> VisitorResult { Ok(()) } - fn visit_fork_stmt(&mut self, _ctx: &mut Ctx, _id: ForkStatementId) -> VisitorResult { Ok(()) } - fn visit_select_stmt(&mut self, _ctx: &mut Ctx, _id: SelectStatementId) -> VisitorResult { Ok(()) } - fn visit_return_stmt(&mut self, _ctx: &mut Ctx, _id: ReturnStatementId) -> VisitorResult { Ok(()) } - fn visit_goto_stmt(&mut self, _ctx: &mut Ctx, _id: GotoStatementId) -> VisitorResult { Ok(()) } - fn visit_new_stmt(&mut self, _ctx: &mut Ctx, _id: NewStatementId) -> VisitorResult { Ok(()) } - fn visit_expr_stmt(&mut self, _ctx: &mut Ctx, _id: ExpressionStatementId) -> VisitorResult { Ok(()) } - - // Expressions - // --- enum matching - fn visit_expr(&mut self, ctx: &mut Ctx, id: ExpressionId) -> VisitorResult { - match &ctx.heap[id] { +macro_rules! 
visitor_recursive_expression_impl { + ($this:expr, $expression:expr, $ctx:expr) => { + match $expression { Expression::Assignment(expr) => { let this = expr.this; - self.visit_assignment_expr(ctx, this) + $this.visit_assignment_expr($ctx, this) }, Expression::Binding(expr) => { let this = expr.this; - self.visit_binding_expr(ctx, this) - } + $this.visit_binding_expr($ctx, this) + }, Expression::Conditional(expr) => { let this = expr.this; - self.visit_conditional_expr(ctx, this) - } + $this.visit_conditional_expr($ctx, this) + }, Expression::Binary(expr) => { let this = expr.this; - self.visit_binary_expr(ctx, this) - } + $this.visit_binary_expr($ctx, this) + }, Expression::Unary(expr) => { let this = expr.this; - self.visit_unary_expr(ctx, this) - } + $this.visit_unary_expr($ctx, this) + }, Expression::Indexing(expr) => { let this = expr.this; - self.visit_indexing_expr(ctx, this) - } + $this.visit_indexing_expr($ctx, this) + }, Expression::Slicing(expr) => { let this = expr.this; - self.visit_slicing_expr(ctx, this) - } + $this.visit_slicing_expr($ctx, this) + }, Expression::Select(expr) => { let this = expr.this; - self.visit_select_expr(ctx, this) - } + $this.visit_select_expr($ctx, this) + }, Expression::Literal(expr) => { let this = expr.this; - self.visit_literal_expr(ctx, this) - } + $this.visit_literal_expr($ctx, this) + }, Expression::Cast(expr) => { let this = expr.this; - self.visit_cast_expr(ctx, this) - } + $this.visit_cast_expr($ctx, this) + }, Expression::Call(expr) => { let this = expr.this; - self.visit_call_expr(ctx, this) - } + $this.visit_call_expr($ctx, this) + }, Expression::Variable(expr) => { let this = expr.this; - self.visit_variable_expr(ctx, this) - } + $this.visit_variable_expr($ctx, this) + }, + } + }; +} + +/// Visitor is a generic trait that will fully walk the AST. The default +/// implementation of the visitors is to not recurse. The exception is the +/// top-level `visit_definition`, `visit_stmt` and `visit_expr` methods, which +/// call the appropriate visitor function. 
+pub(crate) trait Visitor { + // Entry point + fn visit_module(&mut self, ctx: &mut Ctx) -> VisitorResult { + let mut def_index = 0; + let module_root_id = ctx.modules[ctx.module_idx].root_id; + loop { + let definition_id = { + let root = &ctx.heap[module_root_id]; + if def_index >= root.definitions.len() { + return Ok(()) + } + + root.definitions[def_index] + }; + + self.visit_definition(ctx, definition_id)?; + def_index += 1; } } + // Definitions + // --- enum matching + fn visit_definition(&mut self, ctx: &mut Ctx, id: DefinitionId) -> VisitorResult { + return visitor_recursive_definition_impl!(self, &ctx.heap[id], ctx); + } + + // --- enum variant handling + fn visit_enum_definition(&mut self, _ctx: &mut Ctx, _id: EnumDefinitionId) -> VisitorResult { Ok(()) } + fn visit_union_definition(&mut self, _ctx: &mut Ctx, _id: UnionDefinitionId) -> VisitorResult{ Ok(()) } + fn visit_struct_definition(&mut self, _ctx: &mut Ctx, _id: StructDefinitionId) -> VisitorResult { Ok(()) } + fn visit_procedure_definition(&mut self, _ctx: &mut Ctx, _id: ProcedureDefinitionId) -> VisitorResult { Ok(()) } + + // Statements + // --- enum matching + fn visit_stmt(&mut self, ctx: &mut Ctx, id: StatementId) -> VisitorResult { + return visitor_recursive_statement_impl!(self, &ctx.heap[id], ctx, Ok(())); + } + + fn visit_local_stmt(&mut self, ctx: &mut Ctx, id: LocalStatementId) -> VisitorResult { + return visitor_recursive_local_impl!(self, &ctx.heap[id], ctx); + } + + // --- enum variant handling + fn visit_block_stmt(&mut self, _ctx: &mut Ctx, _id: BlockStatementId) -> VisitorResult { Ok(()) } + fn visit_local_memory_stmt(&mut self, _ctx: &mut Ctx, _id: MemoryStatementId) -> VisitorResult { Ok(()) } + fn visit_local_channel_stmt(&mut self, _ctx: &mut Ctx, _id: ChannelStatementId) -> VisitorResult { Ok(()) } + fn visit_labeled_stmt(&mut self, _ctx: &mut Ctx, _id: LabeledStatementId) -> VisitorResult { Ok(()) } + fn visit_if_stmt(&mut self, _ctx: &mut Ctx, _id: IfStatementId) -> VisitorResult { Ok(()) } + fn visit_while_stmt(&mut self, _ctx: &mut Ctx, _id: WhileStatementId) -> VisitorResult { Ok(()) } + fn visit_break_stmt(&mut self, _ctx: &mut Ctx, _id: BreakStatementId) -> VisitorResult { Ok(()) } + fn visit_continue_stmt(&mut self, _ctx: &mut Ctx, _id: ContinueStatementId) -> VisitorResult { Ok(()) } + fn visit_synchronous_stmt(&mut self, _ctx: &mut Ctx, _id: SynchronousStatementId) -> VisitorResult { Ok(()) } + fn visit_fork_stmt(&mut self, _ctx: &mut Ctx, _id: ForkStatementId) -> VisitorResult { Ok(()) } + fn visit_select_stmt(&mut self, _ctx: &mut Ctx, _id: SelectStatementId) -> VisitorResult { Ok(()) } + fn visit_return_stmt(&mut self, _ctx: &mut Ctx, _id: ReturnStatementId) -> VisitorResult { Ok(()) } + fn visit_goto_stmt(&mut self, _ctx: &mut Ctx, _id: GotoStatementId) -> VisitorResult { Ok(()) } + fn visit_new_stmt(&mut self, _ctx: &mut Ctx, _id: NewStatementId) -> VisitorResult { Ok(()) } + fn visit_expr_stmt(&mut self, _ctx: &mut Ctx, _id: ExpressionStatementId) -> VisitorResult { Ok(()) } + + // Expressions + // --- enum matching + fn visit_expr(&mut self, ctx: &mut Ctx, id: ExpressionId) -> VisitorResult { + return visitor_recursive_expression_impl!(self, &ctx.heap[id], ctx); + } + fn visit_assignment_expr(&mut self, _ctx: &mut Ctx, _id: AssignmentExpressionId) -> VisitorResult { Ok(()) } fn visit_binding_expr(&mut self, _ctx: &mut Ctx, _id: BindingExpressionId) -> VisitorResult { Ok(()) } fn visit_conditional_expr(&mut self, _ctx: &mut Ctx, _id: ConditionalExpressionId) -> VisitorResult { Ok(()) } 
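To illustrate the macro-based dispatch this refactor introduces, here is a minimal, self-contained sketch in the same spirit. All names (`Node`, `MiniVisitor`, `NumberSummer`) are hypothetical stand-ins, not the real AST, `Ctx`, or `VisitorResult` types: the macro holds the enum match once, the trait's default methods do nothing, and an implementor overrides only the variants it cares about.

type VisitResult = Result<(), String>;

enum Node {
    Number(i64),
    Text(String),
}

// The match is written once here; both the trait default and any recursive
// visitor built later can expand it, passing their own "skip" result.
macro_rules! dispatch_node_impl {
    ($this:expr, $node:expr, $default:expr) => {
        match $node {
            Node::Number(value) => $this.visit_number(*value),
            Node::Text(_) => $default, // skip arm, analogous to the EndBlock/EndIf cases
        }
    };
}

trait MiniVisitor {
    fn visit_node(&mut self, node: &Node) -> VisitResult {
        dispatch_node_impl!(self, node, Ok(()))
    }
    fn visit_number(&mut self, _value: i64) -> VisitResult { Ok(()) }
}

struct NumberSummer { total: i64 }

impl MiniVisitor for NumberSummer {
    fn visit_number(&mut self, value: i64) -> VisitResult {
        self.total += value;
        Ok(())
    }
}

fn main() {
    let nodes = vec![Node::Number(1), Node::Text("skipped".into()), Node::Number(2)];
    let mut visitor = NumberSummer { total: 0 };
    for node in &nodes {
        visitor.visit_node(node).expect("visit");
    }
    assert_eq!(visitor.total, 3);
}

Factoring the match into a macro also keeps the `$default_return` escape hatch for variants such as `EndBlock` and `EndIf` that should be skipped rather than visited.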
diff --git a/src/protocol/tests/parser_monomorphs.rs b/src/protocol/tests/parser_monomorphs.rs index 2926b727bfa3360ea90779f6de4d5f43ab281694..13098bf77ad99174c738115acb30cb71423108a0 100644 --- a/src/protocol/tests/parser_monomorphs.rs +++ b/src/protocol/tests/parser_monomorphs.rs @@ -57,7 +57,7 @@ fn test_enum_monomorphs() { // Note for reader: because the enum doesn't actually use the polymorphic // variable, we expect to have 1 monomorph: the type only has to be laid - // out once. + // out once. @Deduplication Tester::new_single_source_expect_ok( "single polymorph", " @@ -71,8 +71,10 @@ fn test_enum_monomorphs() { } " ).for_enum("Answer", |e| { e - .assert_num_monomorphs(1) - .assert_has_monomorph("Answer"); + .assert_num_monomorphs(3) + .assert_has_monomorph("Answer") + .assert_has_monomorph("Answer") + .assert_has_monomorph("Answer>>"); }); } diff --git a/src/protocol/tests/parser_validation.rs b/src/protocol/tests/parser_validation.rs index f6cb386849c640aa79f36a3f3d96477e9161917b..a3fc6900e3388a8fff3c4f2dc2fa6d2c91b93097 100644 --- a/src/protocol/tests/parser_validation.rs +++ b/src/protocol/tests/parser_validation.rs @@ -347,7 +347,7 @@ fn test_incorrect_union_instance() { " ).error(|e| { e .assert_occurs_at(0, "Foo::A") - .assert_msg_has(0, "failed to fully resolve") + .assert_msg_has(0, "failed to resolve") .assert_occurs_at(1, "false") .assert_msg_has(1, "has been resolved to 's32'") .assert_msg_has(1, "has been resolved to 'bool'"); @@ -734,7 +734,35 @@ fn test_incorrect_goto_statement() { .assert_occurs_at(0, "goto exit;").assert_msg_has(0, "not escape the surrounding sync") .assert_occurs_at(1, "exit: u32 v").assert_msg_has(1, "target of the goto") .assert_occurs_at(2, "sync {").assert_msg_has(2, "jump past this"); - }) + }); + + Tester::new_single_source_expect_err( + "goto jumping to select case", + "primitive f(in i) { + sync select { + hello: auto a = get(i) -> i += 1 + } + goto hello; + }" + ).error(|e| { e + .assert_msg_has(0, "expected '->'"); + }); + + Tester::new_single_source_expect_err( + "goto jumping into select case skipping variable", + "primitive f(in i) { + goto waza; + sync select { + auto a = get(i) -> { + waza: a += 1; + } + } + }" + ).error(|e| { e + .assert_num(1) + .assert_msg_has(0, "not find this label") + .assert_occurs_at(0, "waza;"); + }); } #[test] diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index 63ff0f95929262c11c44f701f0f31015589658dd..7dff07dacf06403751b5379528373bca3cd94c92 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -218,7 +218,7 @@ impl AstOkTester { pub(crate) fn for_function(self, name: &str, f: F) -> Self { let mut found = false; for definition in self.heap.definitions.iter() { - if let Definition::Function(definition) = definition { + if let Definition::Procedure(definition) = definition { if definition.identifier.value.as_str() != name { continue; } @@ -296,8 +296,8 @@ impl<'a> StructTester<'a> { pub(crate) fn assert_size_alignment(mut self, monomorph: &str, size: usize, alignment: usize) -> Self { self = self.assert_has_monomorph(monomorph); let (mono_idx, _) = has_monomorph(self.ctx, self.ast_def.this.upcast(), monomorph); - let mono_idx = mono_idx.unwrap(); - let mono = self.ctx.types.get_monomorph(mono_idx); + let type_id = mono_idx.unwrap(); + let mono = self.ctx.types.get_monomorph(type_id); assert!( mono.size == size && mono.alignment == alignment, @@ -507,25 +507,23 @@ impl<'a> UnionTester<'a> { pub(crate) struct FunctionTester<'a> { ctx: TestCtx<'a>, - def: &'a 
FunctionDefinition, + def: &'a ProcedureDefinition, } impl<'a> FunctionTester<'a> { - fn new(ctx: TestCtx<'a>, def: &'a FunctionDefinition) -> Self { + fn new(ctx: TestCtx<'a>, def: &'a ProcedureDefinition) -> Self { Self{ ctx, def } } pub(crate) fn for_variable(self, name: &str, f: F) -> Self { // Seek through the blocks in order to find the variable - let wrapping_block_id = seek_stmt( - self.ctx.heap, self.def.body.upcast(), - &|stmt| { - if let Statement::Block(block) = stmt { - for local_id in &block.locals { - let var = &self.ctx.heap[*local_id]; - if var.identifier.value.as_str() == name { - return true; - } + let wrapping_scope = seek_scope( + self.ctx.heap, self.def.scope, + &|scope| { + for variable_id in scope.variables.iter().copied() { + let var = &self.ctx.heap[variable_id]; + if var.identifier.value.as_str() == name { + return true; } } @@ -534,13 +532,13 @@ impl<'a> FunctionTester<'a> { ); let mut found_local_id = None; - if let Some(block_id) = wrapping_block_id { - // Found the right block, find the variable inside the block again - let block_stmt = self.ctx.heap[block_id].as_block(); - for local_id in &block_stmt.locals { - let var = &self.ctx.heap[*local_id]; - if var.identifier.value.as_str() == name { - found_local_id = Some(*local_id); + if let Some(scope_id) = wrapping_scope { + // Found the right scope, find the variable inside the block again + let scope = &self.ctx.heap[scope_id]; + for variable_id in scope.variables.iter().copied() { + let variable = &self.ctx.heap[variable_id]; + if variable.identifier.value.as_str() == name { + found_local_id = Some(variable_id); } } } @@ -702,11 +700,11 @@ impl<'a> FunctionTester<'a> { use crate::protocol::*; // Assuming the function is not polymorphic - let definition_id = self.def.this.upcast(); + let definition_id = self.def.this; let func_type = [ConcreteTypePart::Function(definition_id, 0)]; - let mono_index = self.ctx.types.get_procedure_monomorph_index(&definition_id, &func_type).unwrap(); + let mono_index = self.ctx.types.get_procedure_monomorph_type_id(&definition_id.upcast(), &func_type).unwrap(); - let mut prompt = Prompt::new(&self.ctx.types, &self.ctx.heap, self.def.this.upcast(), mono_index, ValueGroup::new_stack(Vec::new())); + let mut prompt = Prompt::new(&self.ctx.types, &self.ctx.heap, definition_id, mono_index, ValueGroup::new_stack(Vec::new())); let mut call_context = FakeRunContext{}; loop { let result = prompt.step(&self.ctx.types, &self.ctx.heap, &self.ctx.modules, &mut call_context); @@ -750,8 +748,11 @@ impl<'a> VariableTester<'a> { pub(crate) fn assert_concrete_type(self, expected: &str) -> Self { // Lookup concrete type in type table - let mono_data = get_procedure_monomorph(&self.ctx.heap, &self.ctx.types, self.definition_id); - let concrete_type = &mono_data.expr_data[self.var_expr.unique_id_in_definition as usize].expr_type; + let mono_proc = get_procedure_monomorph(&self.ctx.heap, &self.ctx.types, self.definition_id); + let mono_index = mono_proc.monomorph_index; + let mono_data = &self.ctx.heap[self.definition_id].as_procedure().monomorphs[mono_index as usize]; + let expr_info = &mono_data.expr_info[self.var_expr.type_index as usize]; + let concrete_type = &self.ctx.types.get_monomorph(expr_info.type_id).concrete_type; // Serialize and check let serialized = concrete_type.display_name(self.ctx.heap); @@ -784,9 +785,11 @@ impl<'a> ExpressionTester<'a> { pub(crate) fn assert_concrete_type(self, expected: &str) -> Self { // Lookup concrete type - let mono_data = 
get_procedure_monomorph(&self.ctx.heap, &self.ctx.types, self.definition_id); - let expr_index = self.expr.get_unique_id_in_definition(); - let concrete_type = &mono_data.expr_data[expr_index as usize].expr_type; + let mono_proc = get_procedure_monomorph(&self.ctx.heap, &self.ctx.types, self.definition_id); + let mono_index = mono_proc.monomorph_index; + let mono_data = &self.ctx.heap[self.definition_id].as_procedure().monomorphs[mono_index as usize]; + let expr_info = &mono_data.expr_info[self.expr.type_index() as usize]; + let concrete_type = &self.ctx.types.get_monomorph(expr_info.type_id).concrete_type; // Serialize and check type let serialized = concrete_type.display_name(self.ctx.heap); @@ -808,18 +811,15 @@ impl<'a> ExpressionTester<'a> { } fn get_procedure_monomorph<'a>(heap: &Heap, types: &'a TypeTable, definition_id: DefinitionId) -> &'a ProcedureMonomorph { - let ast_definition = &heap[definition_id]; - let func_type = if ast_definition.is_function() { - [ConcreteTypePart::Function(definition_id, 0)] - } else if ast_definition.is_component() { - [ConcreteTypePart::Component(definition_id, 0)] + let ast_definition = heap[definition_id].as_procedure(); + let func_type = if ast_definition.kind == ProcedureKind::Function { + [ConcreteTypePart::Function(ast_definition.this, 0)] } else { - assert!(false); - unreachable!() + [ConcreteTypePart::Component(ast_definition.this, 0)] }; - let mono_index = types.get_procedure_monomorph_index(&definition_id, &func_type).unwrap(); - let mono_data = types.get_procedure_monomorph(mono_index); + let mono_index = types.get_procedure_monomorph_type_id(&definition_id, &func_type).unwrap(); + let mono_data = types.get_monomorph(mono_index).variant.as_procedure(); mono_data } @@ -928,12 +928,16 @@ fn has_equal_num_monomorphs(ctx: TestCtx, num: usize, definition_id: DefinitionI // Again: inefficient, but its testing code let mut num_on_type = 0; - for mono in &ctx.types.mono_lookup.monomorphs { + for mono in &ctx.types.mono_types { match &mono.concrete_type.parts[0] { - ConcreteTypePart::Instance(def_id, _) | + ConcreteTypePart::Instance(def_id, _) => { + if *def_id == definition_id { + num_on_type += 1; + } + } ConcreteTypePart::Function(def_id, _) | ConcreteTypePart::Component(def_id, _) => { - if *def_id == definition_id { + if def_id.upcast() == definition_id { num_on_type += 1; } }, @@ -944,13 +948,13 @@ fn has_equal_num_monomorphs(ctx: TestCtx, num: usize, definition_id: DefinitionI (num_on_type == num, num_on_type) } -fn has_monomorph(ctx: TestCtx, definition_id: DefinitionId, serialized_monomorph: &str) -> (Option, String) { +fn has_monomorph(ctx: TestCtx, definition_id: DefinitionId, serialized_monomorph: &str) -> (Option, String) { // Note: full_buffer is just for error reporting let mut full_buffer = String::new(); let mut has_match = None; full_buffer.push('['); - let mut append_to_full_buffer = |concrete_type: &ConcreteType, mono_idx: usize| { + let mut append_to_full_buffer = |concrete_type: &ConcreteType, type_id: TypeId| { if full_buffer.len() != 1 { full_buffer.push_str(", "); } @@ -959,22 +963,22 @@ fn has_monomorph(ctx: TestCtx, definition_id: DefinitionId, serialized_monomorph let first_idx = full_buffer.len(); full_buffer.push_str(concrete_type.display_name(ctx.heap).as_str()); if &full_buffer[first_idx..] == serialized_monomorph { - has_match = Some(mono_idx as i32); + has_match = Some(type_id); } full_buffer.push('"'); }; // Bit wasteful, but this is (temporary?) 
testing code: - for (mono_idx, mono) in ctx.types.mono_lookup.monomorphs.iter().enumerate() { + for (_mono_idx, mono) in ctx.types.mono_types.iter().enumerate() { let got_definition_id = match &mono.concrete_type.parts[0] { - ConcreteTypePart::Instance(v, _) | + ConcreteTypePart::Instance(v, _) => *v, ConcreteTypePart::Function(v, _) | - ConcreteTypePart::Component(v, _) => *v, + ConcreteTypePart::Component(v, _) => v.upcast(), _ => DefinitionId::new_invalid(), }; if got_definition_id == definition_id { - append_to_full_buffer(&mono.concrete_type, mono_idx); + append_to_full_buffer(&mono.concrete_type, mono.type_id); } } @@ -1098,23 +1102,36 @@ fn seek_stmt bool>(heap: &Heap, start: StatementId, f: &F) }, Statement::Labeled(stmt) => seek_stmt(heap, stmt.body, f), Statement::If(stmt) => { - if let Some(id) = seek_stmt(heap, stmt.true_body.upcast(), f) { + if let Some(id) = seek_stmt(heap, stmt.true_case.body, f) { return Some(id); - } else if let Some(false_body) = stmt.false_body { - if let Some(id) = seek_stmt(heap, false_body.upcast(), f) { + } else if let Some(false_body) = stmt.false_case { + if let Some(id) = seek_stmt(heap, false_body.body, f) { return Some(id); } } None }, - Statement::While(stmt) => seek_stmt(heap, stmt.body.upcast(), f), - Statement::Synchronous(stmt) => seek_stmt(heap, stmt.body.upcast(), f), + Statement::While(stmt) => seek_stmt(heap, stmt.body, f), + Statement::Synchronous(stmt) => seek_stmt(heap, stmt.body, f), _ => None }; matched } +fn seek_scope bool>(heap: &Heap, start: ScopeId, f: &F) -> Option { + let scope = &heap[start]; + if f(scope) { return Some(start); } + + for child_scope_id in scope.nested.iter().copied() { + if let Some(result) = seek_scope(heap, child_scope_id, f) { + return Some(result); + } + } + + return None; +} + fn seek_expr_in_expr bool>(heap: &Heap, start: ExpressionId, f: &F) -> Option { let expr = &heap[start]; if f(expr) { return Some(start); } @@ -1215,9 +1232,9 @@ fn seek_expr_in_stmt bool>(heap: &Heap, start: StatementId Statement::If(stmt) => { None .or_else(|| seek_expr_in_expr(heap, stmt.test, f)) - .or_else(|| seek_expr_in_stmt(heap, stmt.true_body.upcast(), f)) - .or_else(|| if let Some(false_body) = stmt.false_body { - seek_expr_in_stmt(heap, false_body.upcast(), f) + .or_else(|| seek_expr_in_stmt(heap, stmt.true_case.body, f)) + .or_else(|| if let Some(false_body) = stmt.false_case { + seek_expr_in_stmt(heap, false_body.body, f) } else { None }) @@ -1225,10 +1242,10 @@ fn seek_expr_in_stmt bool>(heap: &Heap, start: StatementId Statement::While(stmt) => { None .or_else(|| seek_expr_in_expr(heap, stmt.test, f)) - .or_else(|| seek_expr_in_stmt(heap, stmt.body.upcast(), f)) + .or_else(|| seek_expr_in_stmt(heap, stmt.body, f)) }, Statement::Synchronous(stmt) => { - seek_expr_in_stmt(heap, stmt.body.upcast(), f) + seek_expr_in_stmt(heap, stmt.body, f) }, Statement::Return(stmt) => { for expr_id in &stmt.expressions { @@ -1250,23 +1267,10 @@ fn seek_expr_in_stmt bool>(heap: &Heap, start: StatementId struct FakeRunContext{} impl RunContext for FakeRunContext { - fn performed_put(&mut self, _port: PortId) -> bool { - unreachable!("'put' called in compiler testing code") - } - - fn performed_get(&mut self, _port: PortId) -> Option { - unreachable!("'get' called in compiler testing code") - } - - fn fires(&mut self, _port: PortId) -> Option { - unreachable!("'fires' called in compiler testing code") - } - - fn performed_fork(&mut self) -> Option { - unreachable!("'fork' called in compiler testing code") - } - - fn 
created_channel(&mut self) -> Option<(Value, Value)> { - unreachable!("channel created in compiler testing code") - } + fn performed_put(&mut self, _port: PortId) -> bool { unreachable!() } + fn performed_get(&mut self, _port: PortId) -> Option { unreachable!() } + fn fires(&mut self, _port: PortId) -> Option { unreachable!() } + fn performed_fork(&mut self) -> Option { unreachable!() } + fn created_channel(&mut self) -> Option<(Value, Value)> { unreachable!() } + fn performed_select_wait(&mut self) -> Option { unreachable!() } } \ No newline at end of file diff --git a/src/random.rs b/src/random.rs new file mode 100644 index 0000000000000000000000000000000000000000..46fceaca3874a7b8080cf6b3ce8dc85d38568927 --- /dev/null +++ b/src/random.rs @@ -0,0 +1,39 @@ +/** + * random.rs + * + * Simple wrapper over a random number generator. Put here so that we can have + * a feature flag for particular forms of randomness. For now we'll use pseudo- + * randomness since that will help debugging. + */ + +use rand::{RngCore, SeedableRng}; +use rand_pcg; + +pub(crate) struct Random { + rng: rand_pcg::Lcg64Xsh32, +} + +impl Random { + pub(crate) fn new() -> Self { + use std::time::SystemTime; + + let now = SystemTime::now(); + let elapsed = match now.duration_since(SystemTime::UNIX_EPOCH) { + Ok(elapsed) => elapsed, + Err(err) => err.duration(), + }; + + let elapsed = elapsed.as_nanos(); + let seed = elapsed.to_le_bytes(); + + return Self::new_seeded(seed); + } + + pub(crate) fn new_seeded(seed: [u8; 16]) -> Self { + return Self{ rng: rand_pcg::Pcg32::from_seed(seed) } + } + + pub(crate) fn get_u64(&mut self) -> u64 { + return self.rng.next_u64(); + } +} \ No newline at end of file diff --git a/src/runtime/connector.rs b/src/runtime/connector.rs index e9bce7685c94520e6df7cf873184512aee407e1f..2aabac130545b88c7fb292d41e42d6a7a12b7493 100644 --- a/src/runtime/connector.rs +++ b/src/runtime/connector.rs @@ -122,6 +122,8 @@ impl<'a> RunContext for ConnectorRunContext<'a>{ taken => unreachable!("prepared statement is '{:?}' during 'performed_fork()'", taken), }; } + + fn performed_select_wait(&mut self) -> Option { unreachable!() } } impl Connector for ConnectorPDL { @@ -463,7 +465,7 @@ impl ConnectorPDL { return ConnectorScheduling::Immediate; }, - EvalContinuation::NewComponent(definition_id, monomorph_idx, arguments) => { + EvalContinuation::NewComponent(definition_id, type_id, arguments) => { // Note: we're relinquishing ownership of ports. But because // we are in non-sync mode the scheduler will handle and check // port ownership transfer. 
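The new `src/random.rs` above only exposes `new()`, `new_seeded([u8; 16])` and `get_u64()`. A hedged usage sketch (the `pick_candidate` helper and the test below are hypothetical, assuming `use crate::random::Random;` from inside the crate) shows the intended pattern: seed explicitly in tests, and reduce `get_u64()` onto a slice length, as the select code later in this diff does.

use crate::random::Random; // the crate-internal wrapper introduced above

// Hypothetical helper: pick one entry from a candidate list using the same
// "u64 modulo length" reduction as the select-case choice later in this diff.
fn pick_candidate(rng: &mut Random, candidates: &[usize]) -> Option<usize> {
    if candidates.is_empty() {
        return None;
    }
    let index = rng.get_u64() as usize % candidates.len();
    Some(candidates[index])
}

#[test]
fn pick_candidate_is_deterministic_with_fixed_seed() {
    // A fixed seed makes the draw reproducible, unlike Random::new(), which
    // seeds from the system clock.
    let mut a = Random::new_seeded([7u8; 16]);
    let mut b = Random::new_seeded([7u8; 16]);
    assert_eq!(pick_candidate(&mut a, &[2, 5, 8]), pick_candidate(&mut b, &[2, 5, 8]));
}

Using the PCG generator with an explicit seed keeps a failing scheduling run reproducible, which matches the module comment in `random.rs` about preferring pseudo-randomness while debugging.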
@@ -473,7 +475,7 @@ impl ConnectorPDL { let new_prompt = Prompt::new( &sched_ctx.runtime.protocol_description.types, &sched_ctx.runtime.protocol_description.heap, - definition_id, monomorph_idx, arguments + definition_id, type_id, arguments ); let new_component = ConnectorPDL::new(new_prompt); comp_ctx.push_component(new_component, comp_ctx.workspace_ports.clone()); diff --git a/src/runtime/consensus.rs b/src/runtime/consensus.rs index 1147f447ef8a25ef9a8b8e0754c36b4527efdfc0..616fc07fbfa1db8177754241900e984f7aa1f667 100644 --- a/src/runtime/consensus.rs +++ b/src/runtime/consensus.rs @@ -328,13 +328,13 @@ impl Consensus { let branch = &mut self.branch_annotations[branch_id.index as usize]; let port_info = ctx.get_port_by_id(source_port_id).unwrap(); - if cfg!(debug_assertions) { + dbg_code!({ // Check for consistent mapping let port = branch.channel_mapping.iter() .find(|v| v.channel_id == port_info.channel_id) .unwrap(); debug_assert!(port.expected_firing == None || port.expected_firing == Some(true)); - } + }); // Check for ports that are being sent debug_assert!(self.workspace_ports.is_empty()); diff --git a/src/runtime2/component/component_context.rs b/src/runtime2/component/component_context.rs index 8ed701cd2cf1427747cb35dfab60dcf2b9fd4093..ecd0823e7b5b49bb3004d6273b0adc5c11692fbd 100644 --- a/src/runtime2/component/component_context.rs +++ b/src/runtime2/component/component_context.rs @@ -28,7 +28,7 @@ pub struct CompCtx { port_id_counter: u32, } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, PartialEq, Eq)] pub struct LocalPortHandle(PortId); #[derive(Copy, Clone)] diff --git a/src/runtime2/component/component_pdl.rs b/src/runtime2/component/component_pdl.rs index 8fdb556af27fe536f23f6d19a592551b097e4450..14fd2fb0dd6670acda349042897d5b70743c1c47 100644 --- a/src/runtime2/component/component_pdl.rs +++ b/src/runtime2/component/component_pdl.rs @@ -1,5 +1,6 @@ +use crate::random::Random; use crate::protocol::*; -use crate::protocol::ast::DefinitionId; +use crate::protocol::ast::ProcedureDefinitionId; use crate::protocol::eval::{ PortId as EvalPortId, Prompt, ValueGroup, Value, @@ -24,6 +25,7 @@ pub enum ExecStmt { CreatedChannel((Value, Value)), PerformedPut, PerformedGet(ValueGroup), + PerformedSelectWait(u32), None, } @@ -78,6 +80,14 @@ impl RunContext for ExecCtx { _ => unreachable!(), } } + + fn performed_select_wait(&mut self) -> Option { + match self.stmt.take() { + ExecStmt::None => return None, + ExecStmt::PerformedSelectWait(selected_case) => Some(selected_case), + _v => unreachable!(), + } + } } #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -85,17 +95,136 @@ pub(crate) enum Mode { NonSync, // not in sync mode Sync, // in sync mode, can interact with other components SyncEnd, // awaiting a solution, i.e. 
encountered the end of the sync block - BlockedGet, - BlockedPut, + BlockedGet, // blocked because we need to receive a message on a particular port + BlockedPut, // component is blocked because the port is blocked + BlockedSelect, // waiting on message to complete the select statement StartExit, // temporary state: if encountered then we start the shutdown process BusyExit, // temporary state: waiting for Acks for all the closed ports Exit, // exiting: shutdown process started, now waiting until the reference count drops to 0 } +struct SelectCase { + involved_ports: Vec, +} + +// TODO: @Optimize, flatten cases into single array, have index-pointers to next case +struct SelectState { + cases: Vec, + next_case: u32, + num_cases: u32, + random: Random, + candidates_workspace: Vec, +} + +enum SelectDecision { + None, + Case(u32), // contains case index, should be passed along to PDL code +} + +type InboxMain = Vec>; + +impl SelectState { + fn new() -> Self { + return Self{ + cases: Vec::new(), + next_case: 0, + num_cases: 0, + random: Random::new(), + candidates_workspace: Vec::new(), + } + } + + fn handle_select_start(&mut self, num_cases: u32) { + self.cases.clear(); + self.next_case = 0; + self.num_cases = num_cases; + } + + /// Register a port as belonging to a particular case. As for correctness of + /// PDL code one cannot register the same port twice, this function might + /// return an error + fn register_select_case_port(&mut self, comp_ctx: &CompCtx, case_index: u32, _port_index: u32, port_id: PortId) -> Result<(), PortId> { + // Retrieve case and port handle + self.ensure_at_case(case_index); + let cur_case = &mut self.cases[case_index as usize]; + let port_handle = comp_ctx.get_port_handle(port_id); + debug_assert_eq!(cur_case.involved_ports.len(), _port_index as usize); + + // Make sure port wasn't added before, we disallow having the same port + // in the same select guard twice. + if cur_case.involved_ports.contains(&port_handle) { + return Err(port_id); + } + + cur_case.involved_ports.push(port_handle); + return Ok(()); + } + + /// Notification that all ports have been registered and we should now wait + /// until the appropriate messages have come in. + fn handle_select_waiting_point(&mut self, inbox: &InboxMain, comp_ctx: &CompCtx) -> SelectDecision { + if self.num_cases != self.next_case { + // This happens when there are >=1 select cases written at the end + // of the select block. + self.ensure_at_case(self.num_cases - 1); + } + + return self.has_decision(inbox, comp_ctx); + } + + fn handle_updated_inbox(&mut self, inbox: &InboxMain, comp_ctx: &CompCtx) -> SelectDecision { + return self.has_decision(inbox, comp_ctx); + } + + /// Internal helper, pushes empty cases inbetween last case and provided new + /// case index. + fn ensure_at_case(&mut self, new_case_index: u32) { + // Push an empty case for all intermediate cases that were not + // registered with a port. + debug_assert!(new_case_index >= self.next_case && new_case_index < self.num_cases); + for _ in self.next_case..new_case_index + 1 { + self.cases.push(SelectCase{ involved_ports: Vec::new() }); + } + self.next_case = new_case_index + 1; + } + + /// Checks if a decision can be reached + fn has_decision(&mut self, inbox: &InboxMain, comp_ctx: &CompCtx) -> SelectDecision { + self.candidates_workspace.clear(); + if self.cases.is_empty() { + // If there are no cases then we can immediately reach a "bogus + // decision". 
+ return SelectDecision::Case(0); + } + + // Need to check for valid case + 'case_loop: for (case_index, case) in self.cases.iter().enumerate() { + for port_handle in case.involved_ports.iter().copied() { + let port_index = comp_ctx.get_port_index(port_handle); + if inbox[port_index].is_none() { + // Condition not satisfied + continue 'case_loop; + } + } + + // If here then the case guard is satisfied + self.candidates_workspace.push(case_index); + } + + if self.candidates_workspace.is_empty() { + return SelectDecision::None; + } else { + let candidate_index = self.random.get_u64() as usize % self.candidates_workspace.len(); + return SelectDecision::Case(self.candidates_workspace[candidate_index] as u32); + } + } +} + pub(crate) struct CompPDL { pub mode: Mode, pub mode_port: PortId, // when blocked on a port pub mode_value: ValueGroup, // when blocked on a put + select: SelectState, pub prompt: Prompt, pub control: ControlLayer, pub consensus: Consensus, @@ -105,7 +234,7 @@ pub(crate) struct CompPDL { // reserved per port. // Should be same length as the number of ports. Corresponding indices imply // message is intended for that port. - pub inbox_main: Vec>, + pub inbox_main: InboxMain, pub inbox_backup: Vec, } @@ -121,6 +250,7 @@ impl CompPDL { mode: Mode::NonSync, mode_port: PortId::new_invalid(), mode_value: ValueGroup::default(), + select: SelectState::new(), prompt: initial_state, control: ControlLayer::default(), consensus: Consensus::new(), @@ -168,7 +298,9 @@ impl CompPDL { // Depending on the mode don't do anything at all, take some special // actions, or fall through and run the PDL code. match self.mode { - Mode::NonSync | Mode::Sync => {}, + Mode::NonSync | Mode::Sync | Mode::BlockedSelect => { + // continue and run PDL code + }, Mode::SyncEnd | Mode::BlockedGet | Mode::BlockedPut => { return Ok(CompScheduling::Sleep); } @@ -230,6 +362,7 @@ impl CompPDL { }, EC::Put(port_id, value) => { debug_assert_eq!(self.mode, Mode::Sync); + sched_ctx.log(&format!("Putting value {:?}", value)); let port_id = port_id_from_eval(port_id); let port_handle = comp_ctx.get_port_handle(port_id); let port_info = comp_ctx.get_port(port_handle); @@ -245,6 +378,33 @@ impl CompPDL { return Ok(CompScheduling::Immediate); } }, + EC::SelectStart(num_cases, _num_ports) => { + debug_assert_eq!(self.mode, Mode::Sync); + self.select.handle_select_start(num_cases); + return Ok(CompScheduling::Requeue); + }, + EC::SelectRegisterPort(case_index, port_index, port_id) => { + debug_assert_eq!(self.mode, Mode::Sync); + let port_id = port_id_from_eval(port_id); + if let Err(_err) = self.select.register_select_case_port(comp_ctx, case_index, port_index, port_id) { + todo!("handle registering a port multiple times"); + } + return Ok(CompScheduling::Immediate); + }, + EC::SelectWait => { + debug_assert_eq!(self.mode, Mode::Sync); + let select_decision = self.select.handle_select_waiting_point(&self.inbox_main, comp_ctx); + if let SelectDecision::Case(case_index) = select_decision { + // Reached a conclusion, so we can continue immediately + self.exec_ctx.stmt = ExecStmt::PerformedSelectWait(case_index); + self.mode = Mode::Sync; + return Ok(CompScheduling::Immediate); + } else { + // No decision yet + self.mode = Mode::BlockedSelect; + return Ok(CompScheduling::Sleep); + } + }, // Results that can be returned outside of sync mode EC::ComponentTerminated => { self.mode = Mode::StartExit; // next call we'll take care of the exit @@ -255,11 +415,11 @@ impl CompPDL { self.handle_sync_start(sched_ctx, comp_ctx); return 
Ok(CompScheduling::Immediate); }, - EC::NewComponent(definition_id, monomorph_idx, arguments) => { + EC::NewComponent(definition_id, type_id, arguments) => { debug_assert_eq!(self.mode, Mode::NonSync); self.create_component_and_transfer_ports( sched_ctx, comp_ctx, - definition_id, monomorph_idx, arguments + definition_id, type_id, arguments ); return Ok(CompScheduling::Requeue); }, @@ -311,7 +471,7 @@ impl CompPDL { /// Handles decision from the consensus round. This will cause a change in /// the internal `Mode`, such that the next call to `run` can take the /// appropriate next steps. - fn handle_sync_decision(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx, decision: SyncRoundDecision) { + fn handle_sync_decision(&mut self, sched_ctx: &SchedulerCtx, _comp_ctx: &mut CompCtx, decision: SyncRoundDecision) { sched_ctx.log(&format!("Handling sync decision: {:?} (in mode {:?})", decision, self.mode)); let is_success = match decision { SyncRoundDecision::None => { @@ -389,6 +549,12 @@ impl CompPDL { // We were indeed blocked self.mode = Mode::Sync; self.mode_port = PortId::new_invalid(); + } else if self.mode == Mode::BlockedSelect { + let select_decision = self.select.handle_updated_inbox(&self.inbox_main, comp_ctx); + if let SelectDecision::Case(case_index) = select_decision { + self.exec_ctx.stmt = ExecStmt::PerformedSelectWait(case_index); + self.mode = Mode::Sync; + } } return; @@ -598,7 +764,7 @@ impl CompPDL { fn create_component_and_transfer_ports( &mut self, sched_ctx: &SchedulerCtx, creator_ctx: &mut CompCtx, - definition_id: DefinitionId, monomorph_index: i32, mut arguments: ValueGroup + definition_id: ProcedureDefinitionId, type_id: TypeId, mut arguments: ValueGroup ) { struct PortPair{ creator_handle: LocalPortHandle, @@ -692,7 +858,7 @@ impl CompPDL { // to message exchanges between remote peers. let prompt = Prompt::new( &sched_ctx.runtime.protocol.types, &sched_ctx.runtime.protocol.heap, - definition_id, monomorph_index, arguments, + definition_id, type_id, arguments, ); let component = CompPDL::new(prompt, port_id_pairs.len()); let (created_key, component) = sched_ctx.runtime.finish_create_pdl_component( diff --git a/src/runtime2/component/consensus.rs b/src/runtime2/component/consensus.rs index 4683ff164f11a5d64815a14a18ef5d94395ad24e..666ecf3c637103d1c7646b01252373c244637536 100644 --- a/src/runtime2/component/consensus.rs +++ b/src/runtime2/component/consensus.rs @@ -30,6 +30,8 @@ enum Mode { NonSync, SyncBusy, SyncAwaitingSolution, + SelectBusy, + SelectWait, } struct SolutionCombiner { @@ -99,7 +101,7 @@ impl SolutionCombiner { /// Combines the currently stored global solution (if any) with the newly /// provided local solution. Make sure to check the `has_decision` return /// value afterwards. 
- fn combine_with_local_solution(&mut self, comp_id: CompId, solution: SyncLocalSolution) { + fn combine_with_local_solution(&mut self, _comp_id: CompId, solution: SyncLocalSolution) { debug_assert_ne!(self.solution.decision, SyncRoundDecision::Solution); // Combine partial solution with the local solution entries @@ -422,6 +424,19 @@ impl Consensus { self.highest_id = header.highest_id; for peer in comp_ctx.iter_peers() { if peer.id == header.sending_id { + continue; // do not send to sender: it has the higher ID + } + + // also: only send if we received a message in this round + let mut performed_communication = false; // TODO: Revise, temporary fix + for port in self.ports.iter() { + if port.peer_comp_id == peer.id && port.mapping.is_some() { + performed_communication = true; + break; + } + } + + if !performed_communication { continue; } diff --git a/src/runtime2/store/component.rs b/src/runtime2/store/component.rs index 0370e5956540d36a4784546678e35ab0dca7408f..65e0bbc75d1999b762b56ff7a78cc44f5cb36c2c 100644 --- a/src/runtime2/store/component.rs +++ b/src/runtime2/store/component.rs @@ -534,7 +534,7 @@ mod tests { } else { // Must destroy let stored_index = new_value as usize % stored.len(); - let (el_index, el_value) = stored.remove(stored_index); + let (el_index, _el_value) = stored.remove(stored_index); store.destroy(el_index); } } diff --git a/src/runtime2/tests/mod.rs b/src/runtime2/tests/mod.rs index ce820c8a0167bf00581b238e07141be8c6703995..2a481ff71dd208cea4ea0b3705c1192256b281a3 100644 --- a/src/runtime2/tests/mod.rs +++ b/src/runtime2/tests/mod.rs @@ -25,7 +25,7 @@ fn test_component_creation() { ").expect("compilation"); let rt = Runtime::new(1, true, pd); - for i in 0..20 { + for _i in 0..20 { create_component(&rt, "", "nothing_at_all", no_args()); } } @@ -82,4 +82,92 @@ fn test_component_communication() { }").expect("compilation"); let rt = Runtime::new(3, true, pd); create_component(&rt, "", "constructor", no_args()); +} + +#[test] +fn test_simple_select() { + let pd = ProtocolDescription::parse(b" + func infinite_assert(T val, T expected) -> () { + while (val != expected) { print(\"nope!\"); } + return (); + } + + primitive receiver(in in_a, in in_b, u32 num_sends) { + auto num_from_a = 0; + auto num_from_b = 0; + while (num_from_a + num_from_b < 2 * num_sends) { + sync select { + auto v = get(in_a) -> { + print(\"got something from A\"); + auto _ = infinite_assert(v, num_from_a); + num_from_a += 1; + } + auto v = get(in_b) -> { + print(\"got something from B\"); + auto _ = infinite_assert(v, num_from_b); + num_from_b += 1; + } + } + } + } + + primitive sender(out tx, u32 num_sends) { + auto index = 0; + while (index < num_sends) { + sync { + put(tx, index); + index += 1; + } + } + } + + composite constructor() { + auto num_sends = 15; + channel tx_a -> rx_a; + channel tx_b -> rx_b; + new sender(tx_a, num_sends); + new receiver(rx_a, rx_b, num_sends); + new sender(tx_b, num_sends); + } + ").expect("compilation"); + let rt = Runtime::new(3, false, pd); + create_component(&rt, "", "constructor", no_args()); +} + +#[test] +fn test_unguarded_select() { + let pd = ProtocolDescription::parse(b" + primitive constructor_outside_select() { + u32 index = 0; + while (index < 5) { + sync select { auto v = () -> print(\"hello\"); } + index += 1; + } + } + + primitive constructor_inside_select() { + u32 index = 0; + while (index < 5) { + sync select { auto v = () -> index += 1; } + } + } + ").expect("compilation"); + let rt = Runtime::new(3, false, pd); + create_component(&rt, "", 
"constructor_outside_select", no_args()); + create_component(&rt, "", "constructor_inside_select", no_args()); +} + +#[test] +fn test_empty_select() { + let pd = ProtocolDescription::parse(b" + primitive constructor() { + u32 index = 0; + while (index < 5) { + sync select {} + index += 1; + } + } + ").expect("compilation"); + let rt = Runtime::new(3, false, pd); + create_component(&rt, "", "constructor", no_args()); } \ No newline at end of file