diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index aad546039890c93261d632c78ff150074333682f..9459e99d1b75e09768d4dbc41f496a118eb4f801 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -8,384 +8,78 @@ use super::arena::{Arena, Id}; // TODO: @cleanup, transform wrapping types into type aliases where possible use crate::protocol::inputsource::*; -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] -pub struct RootId(pub(crate) Id); - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct PragmaId(pub(crate) Id); - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct ImportId(pub(crate) Id); - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct TypeAnnotationId(pub(crate) Id); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] -pub struct VariableId(pub(crate) Id); - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] -pub struct ParameterId(pub(crate) VariableId); - -impl ParameterId { - pub fn upcast(self) -> VariableId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] -pub struct LocalId(pub(crate) VariableId); - -impl LocalId { - pub fn upcast(self) -> VariableId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] -pub struct DefinitionId(pub(crate) Id); - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct StructId(pub(crate) DefinitionId); - -impl StructId { - pub fn upcast(self) -> DefinitionId{ - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct EnumId(pub(crate) DefinitionId); - -impl EnumId { - pub fn upcast(self) -> DefinitionId{ - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct ComponentId(pub(crate) DefinitionId); - -impl ComponentId { - pub fn upcast(self) -> DefinitionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct FunctionId(pub(crate) DefinitionId); - -impl FunctionId { - pub fn upcast(self) -> DefinitionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct StatementId(pub(crate) Id); - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -// TODO: Remove pub -pub struct BlockStatementId(pub StatementId); - -impl BlockStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct LocalStatementId(pub(crate) StatementId); - -impl LocalStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct MemoryStatementId(pub(crate) LocalStatementId); - -impl MemoryStatementId { - pub fn upcast(self) -> LocalStatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct ChannelStatementId(pub(crate) LocalStatementId); - -impl ChannelStatementId { - pub fn upcast(self) -> LocalStatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct SkipStatementId(pub(crate) StatementId); - -impl SkipStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct LabeledStatementId(pub(crate) StatementId); - -impl LabeledStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct IfStatementId(pub(crate) StatementId); - -impl IfStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct EndIfStatementId(pub(crate) StatementId); - -impl EndIfStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct WhileStatementId(pub(crate) StatementId); - -impl WhileStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct EndWhileStatementId(pub(crate) StatementId); - -impl EndWhileStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct BreakStatementId(pub(crate) StatementId); - -impl BreakStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct ContinueStatementId(pub(crate) StatementId); - -impl ContinueStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct SynchronousStatementId(pub(crate) StatementId); - -impl SynchronousStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct EndSynchronousStatementId(pub(crate) StatementId); - -impl EndSynchronousStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct ReturnStatementId(pub(crate) StatementId); - -impl ReturnStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct AssertStatementId(pub(crate) StatementId); - -impl AssertStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct GotoStatementId(pub(crate) StatementId); - -impl GotoStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct NewStatementId(pub(crate) StatementId); - -impl NewStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct PutStatementId(pub(crate) StatementId); - -impl PutStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct ExpressionStatementId(pub(crate) StatementId); - -impl ExpressionStatementId { - pub fn upcast(self) -> StatementId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct ExpressionId(pub(crate) Id); - -#[derive(Debug, Clone, Copy, serde::Serialize, serde::Deserialize)] -pub struct AssignmentExpressionId(pub(crate) ExpressionId); - -impl AssignmentExpressionId { - pub fn upcast(self) -> ExpressionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct ConditionalExpressionId(pub(crate) ExpressionId); - -impl ConditionalExpressionId { - pub fn upcast(self) -> ExpressionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct BinaryExpressionId(pub(crate) ExpressionId); - -impl BinaryExpressionId { - pub fn upcast(self) -> ExpressionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct UnaryExpressionId(pub(crate) ExpressionId); - -impl UnaryExpressionId { - pub fn upcast(self) -> ExpressionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct IndexingExpressionId(pub(crate) ExpressionId); - -impl IndexingExpressionId { - pub fn upcast(self) -> ExpressionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct SlicingExpressionId(pub(crate) ExpressionId); - -impl SlicingExpressionId { - pub fn upcast(self) -> ExpressionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct SelectExpressionId(pub(crate) ExpressionId); - -impl SelectExpressionId { - pub fn upcast(self) -> ExpressionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct ArrayExpressionId(pub(crate) ExpressionId); - -impl ArrayExpressionId { - pub fn upcast(self) -> ExpressionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct ConstantExpressionId(pub(crate) ExpressionId); - -impl ConstantExpressionId { - pub fn upcast(self) -> ExpressionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct CallExpressionId(pub(crate) ExpressionId); - -impl CallExpressionId { - pub fn upcast(self) -> ExpressionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct VariableExpressionId(pub(crate) ExpressionId); - -impl VariableExpressionId { - pub fn upcast(self) -> ExpressionId { - self.0 - } -} - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct DeclarationId(Id); - -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct DefinedDeclarationId(DeclarationId); - -impl DefinedDeclarationId { - pub fn upcast(self) -> DeclarationId { - self.0 +macro_rules! define_aliased_ast_id { + ($name:ident, $parent:ty) => { + type $name = $parent; } } -#[derive(Debug, Clone, Copy, PartialEq, serde::Serialize, serde::Deserialize)] -pub struct ImportedDeclarationId(DeclarationId); +macro_rules! define_new_ast_id { + ($name:ident, $parent:ty) => { + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] + pub struct $name (pub(crate) $parent); -impl ImportedDeclarationId { - pub fn upcast(self) -> DeclarationId { - self.0 - } -} + impl $name { + pub fn upcast(self) -> $parent { + self.0 + } + } + }; +} + +define_aliased_ast_id!(RootId, Id); +define_aliased_ast_id!(PragmaId, Id); +define_aliased_ast_id!(ImportId, Id); +define_aliased_ast_id!(TypeAnnotationId, Id); + +define_aliased_ast_id!(VariableId, Id); +define_new_ast_id!(ParameterId, VariableId); +define_new_ast_id!(LocalId, VariableId); + +define_aliased_ast_id!(DefinitionId, Id); +define_new_ast_id!(StructId, DefinitionId); +define_new_ast_id!(EnumId, DefinitionId); +define_new_ast_id!(ComponentId, DefinitionId); +define_new_ast_id!(FunctionId, DefinitionId); + +define_aliased_ast_id!(StatementId, Id); +define_new_ast_id!(BlockStatementId, StatementId); +define_new_ast_id!(LocalStatementId, StatementId); +define_new_ast_id!(MemoryStatementId, LocalStatementId); +define_new_ast_id!(ChannelStatementId, LocalStatementId); +define_new_ast_id!(SkipStatementId, StatementId); +define_new_ast_id!(LabeledStatementId, StatementId); +define_new_ast_id!(IfStatementId, StatementId); +define_new_ast_id!(EndIfStatementId, StatementId); +define_new_ast_id!(WhileStatementId, StatementId); +define_new_ast_id!(EndWhileStatementId, StatementId); +define_new_ast_id!(BreakStatementId, StatementId); +define_new_ast_id!(ContinueStatementId, StatementId); +define_new_ast_id!(SynchronousStatementId, StatementId); +define_new_ast_id!(EndSynchronousStatementId, StatementId); +define_new_ast_id!(ReturnStatementId, StatementId); +define_new_ast_id!(AssertStatementId, StatementId); +define_new_ast_id!(GotoStatementId, StatementId); +define_new_ast_id!(NewStatementId, StatementId); +define_new_ast_id!(PutStatementId, StatementId); +define_new_ast_id!(ExpressionStatementId, StatementId); + +define_aliased_ast_id!(ExpressionId, Id); +define_new_ast_id!(AssignmentExpressionId, ExpressionId); +define_new_ast_id!(ConditionalExpressionId, ExpressionId); +define_new_ast_id!(BinaryExpressionId, ExpressionId); +define_new_ast_id!(UnaryExpressionId, ExpressionId); +define_new_ast_id!(IndexingExpressionId, ExpressionId); +define_new_ast_id!(SlicingExpressionId, ExpressionId); +define_new_ast_id!(SelectExpressionId, ExpressionId); +define_new_ast_id!(ArrayExpressionId, ExpressionId); +define_new_ast_id!(ConstantExpressionId, ExpressionId); +define_new_ast_id!(CallExpressionId, ExpressionId); +define_new_ast_id!(VariableExpressionId, ExpressionId); + +define_new_ast_id!(DeclarationId, ExpressionId); // TODO: @cleanup +define_new_ast_id!(DefinedDeclarationId, DeclarationId); +define_new_ast_id!(ImportedDeclarationId, DeclarationId); // TODO: @cleanup - pub qualifiers can be removed once done #[derive(Debug, serde::Serialize, serde::Deserialize)] @@ -427,208 +121,212 @@ impl Heap { &mut self, f: impl FnOnce(TypeAnnotationId) -> TypeAnnotation, ) -> TypeAnnotationId { - TypeAnnotationId(self.type_annotations.alloc_with_id(|id| f(TypeAnnotationId(id)))) + self.type_annotations.alloc_with_id(|id| f(id)) } pub fn alloc_parameter(&mut self, f: impl FnOnce(ParameterId) -> Parameter) -> ParameterId { - ParameterId(VariableId( - self.variables.alloc_with_id(|id| Variable::Parameter(f(ParameterId(VariableId(id))))), - )) + ParameterId( + self.variables.alloc_with_id(|id| Variable::Parameter(f(ParameterId(id)))), + ) } pub fn alloc_local(&mut self, f: impl FnOnce(LocalId) -> Local) -> LocalId { - LocalId(VariableId( - self.variables.alloc_with_id(|id| Variable::Local(f(LocalId(VariableId(id))))), - )) + LocalId( + self.variables.alloc_with_id(|id| Variable::Local(f(LocalId(id)))), + ) } pub fn alloc_assignment_expression( &mut self, f: impl FnOnce(AssignmentExpressionId) -> AssignmentExpression, ) -> AssignmentExpressionId { - AssignmentExpressionId(ExpressionId(self.expressions.alloc_with_id(|id| { - Expression::Assignment(f(AssignmentExpressionId(ExpressionId(id)))) - }))) + AssignmentExpressionId( + self.expressions.alloc_with_id(|id| { + Expression::Assignment(f(AssignmentExpressionId(id))) + }) + ) } pub fn alloc_conditional_expression( &mut self, f: impl FnOnce(ConditionalExpressionId) -> ConditionalExpression, ) -> ConditionalExpressionId { - ConditionalExpressionId(ExpressionId(self.expressions.alloc_with_id(|id| { - Expression::Conditional(f(ConditionalExpressionId(ExpressionId(id)))) - }))) + ConditionalExpressionId( + self.expressions.alloc_with_id(|id| { + Expression::Conditional(f(ConditionalExpressionId(id))) + }) + ) } pub fn alloc_binary_expression( &mut self, f: impl FnOnce(BinaryExpressionId) -> BinaryExpression, ) -> BinaryExpressionId { - BinaryExpressionId(ExpressionId( + BinaryExpressionId( self.expressions - .alloc_with_id(|id| Expression::Binary(f(BinaryExpressionId(ExpressionId(id))))), - )) + .alloc_with_id(|id| Expression::Binary(f(BinaryExpressionId(id)))), + ) } pub fn alloc_unary_expression( &mut self, f: impl FnOnce(UnaryExpressionId) -> UnaryExpression, ) -> UnaryExpressionId { - UnaryExpressionId(ExpressionId( + UnaryExpressionId( self.expressions - .alloc_with_id(|id| Expression::Unary(f(UnaryExpressionId(ExpressionId(id))))), - )) + .alloc_with_id(|id| Expression::Unary(f(UnaryExpressionId(id)))), + ) } pub fn alloc_slicing_expression( &mut self, f: impl FnOnce(SlicingExpressionId) -> SlicingExpression, ) -> SlicingExpressionId { - SlicingExpressionId(ExpressionId( + SlicingExpressionId( self.expressions - .alloc_with_id(|id| Expression::Slicing(f(SlicingExpressionId(ExpressionId(id))))), - )) + .alloc_with_id(|id| Expression::Slicing(f(SlicingExpressionId(id)))), + ) } pub fn alloc_indexing_expression( &mut self, f: impl FnOnce(IndexingExpressionId) -> IndexingExpression, ) -> IndexingExpressionId { - IndexingExpressionId(ExpressionId( + IndexingExpressionId( self.expressions.alloc_with_id(|id| { - Expression::Indexing(f(IndexingExpressionId(ExpressionId(id)))) + Expression::Indexing(f(IndexingExpressionId(id))) }), - )) + ) } pub fn alloc_select_expression( &mut self, f: impl FnOnce(SelectExpressionId) -> SelectExpression, ) -> SelectExpressionId { - SelectExpressionId(ExpressionId( + SelectExpressionId( self.expressions - .alloc_with_id(|id| Expression::Select(f(SelectExpressionId(ExpressionId(id))))), - )) + .alloc_with_id(|id| Expression::Select(f(SelectExpressionId(id)))), + ) } pub fn alloc_array_expression( &mut self, f: impl FnOnce(ArrayExpressionId) -> ArrayExpression, ) -> ArrayExpressionId { - ArrayExpressionId(ExpressionId( + ArrayExpressionId( self.expressions - .alloc_with_id(|id| Expression::Array(f(ArrayExpressionId(ExpressionId(id))))), - )) + .alloc_with_id(|id| Expression::Array(f(ArrayExpressionId(id)))), + ) } pub fn alloc_constant_expression( &mut self, f: impl FnOnce(ConstantExpressionId) -> ConstantExpression, ) -> ConstantExpressionId { - ConstantExpressionId(ExpressionId( + ConstantExpressionId( self.expressions.alloc_with_id(|id| { - Expression::Constant(f(ConstantExpressionId(ExpressionId(id)))) + Expression::Constant(f(ConstantExpressionId(id))) }), - )) + ) } pub fn alloc_call_expression( &mut self, f: impl FnOnce(CallExpressionId) -> CallExpression, ) -> CallExpressionId { - CallExpressionId(ExpressionId( + CallExpressionId( self.expressions - .alloc_with_id(|id| Expression::Call(f(CallExpressionId(ExpressionId(id))))), - )) + .alloc_with_id(|id| Expression::Call(f(CallExpressionId(id)))), + ) } pub fn alloc_variable_expression( &mut self, f: impl FnOnce(VariableExpressionId) -> VariableExpression, ) -> VariableExpressionId { - VariableExpressionId(ExpressionId( + VariableExpressionId( self.expressions.alloc_with_id(|id| { - Expression::Variable(f(VariableExpressionId(ExpressionId(id)))) + Expression::Variable(f(VariableExpressionId(id))) }), - )) + ) } pub fn alloc_block_statement( &mut self, f: impl FnOnce(BlockStatementId) -> BlockStatement, ) -> BlockStatementId { - BlockStatementId(StatementId( + BlockStatementId( self.statements - .alloc_with_id(|id| Statement::Block(f(BlockStatementId(StatementId(id))))), - )) + .alloc_with_id(|id| Statement::Block(f(BlockStatementId(id)))), + ) } pub fn alloc_memory_statement( &mut self, f: impl FnOnce(MemoryStatementId) -> MemoryStatement, ) -> MemoryStatementId { - MemoryStatementId(LocalStatementId(StatementId(self.statements.alloc_with_id(|id| { - Statement::Local(LocalStatement::Memory(f(MemoryStatementId(LocalStatementId( - StatementId(id), - ))))) - })))) + MemoryStatementId(LocalStatementId(self.statements.alloc_with_id(|id| { + Statement::Local(LocalStatement::Memory( + f(MemoryStatementId(LocalStatementId(id))) + )) + }))) } pub fn alloc_channel_statement( &mut self, f: impl FnOnce(ChannelStatementId) -> ChannelStatement, ) -> ChannelStatementId { - ChannelStatementId(LocalStatementId(StatementId(self.statements.alloc_with_id(|id| { - Statement::Local(LocalStatement::Channel(f(ChannelStatementId(LocalStatementId( - StatementId(id), - ))))) - })))) + ChannelStatementId(LocalStatementId(self.statements.alloc_with_id(|id| { + Statement::Local(LocalStatement::Channel( + f(ChannelStatementId(LocalStatementId(id))) + )) + }))) } pub fn alloc_skip_statement( &mut self, f: impl FnOnce(SkipStatementId) -> SkipStatement, ) -> SkipStatementId { - SkipStatementId(StatementId( + SkipStatementId( self.statements - .alloc_with_id(|id| Statement::Skip(f(SkipStatementId(StatementId(id))))), - )) + .alloc_with_id(|id| Statement::Skip(f(SkipStatementId(id)))), + ) } pub fn alloc_if_statement( &mut self, f: impl FnOnce(IfStatementId) -> IfStatement, ) -> IfStatementId { - IfStatementId(StatementId( - self.statements.alloc_with_id(|id| Statement::If(f(IfStatementId(StatementId(id))))), - )) + IfStatementId( + self.statements.alloc_with_id(|id| Statement::If(f(IfStatementId(id)))), + ) } pub fn alloc_end_if_statement( &mut self, f: impl FnOnce(EndIfStatementId) -> EndIfStatement, ) -> EndIfStatementId { - EndIfStatementId(StatementId( + EndIfStatementId( self.statements - .alloc_with_id(|id| Statement::EndIf(f(EndIfStatementId(StatementId(id))))), - )) + .alloc_with_id(|id| Statement::EndIf(f(EndIfStatementId(id)))), + ) } pub fn alloc_while_statement( &mut self, f: impl FnOnce(WhileStatementId) -> WhileStatement, ) -> WhileStatementId { - WhileStatementId(StatementId( + WhileStatementId( self.statements - .alloc_with_id(|id| Statement::While(f(WhileStatementId(StatementId(id))))), - )) + .alloc_with_id(|id| Statement::While(f(WhileStatementId(id)))), + ) } pub fn alloc_end_while_statement( &mut self, f: impl FnOnce(EndWhileStatementId) -> EndWhileStatement, ) -> EndWhileStatementId { - EndWhileStatementId(StatementId( + EndWhileStatementId( self.statements - .alloc_with_id(|id| Statement::EndWhile(f(EndWhileStatementId(StatementId(id))))), - )) + .alloc_with_id(|id| Statement::EndWhile(f(EndWhileStatementId(id)))), + ) } pub fn alloc_break_statement( &mut self, f: impl FnOnce(BreakStatementId) -> BreakStatement, ) -> BreakStatementId { - BreakStatementId(StatementId( + BreakStatementId( self.statements - .alloc_with_id(|id| Statement::Break(f(BreakStatementId(StatementId(id))))), - )) + .alloc_with_id(|id| Statement::Break(f(BreakStatementId(id)))), + ) } pub fn alloc_continue_statement( &mut self, f: impl FnOnce(ContinueStatementId) -> ContinueStatement, ) -> ContinueStatementId { - ContinueStatementId(StatementId( + ContinueStatementId( self.statements - .alloc_with_id(|id| Statement::Continue(f(ContinueStatementId(StatementId(id))))), - )) + .alloc_with_id(|id| Statement::Continue(f(ContinueStatementId(id)))), + ) } pub fn alloc_synchronous_statement( &mut self, diff --git a/src/protocol/parser/mod.rs b/src/protocol/parser/mod.rs index 331bf24f468336f6c960ffe784e78f76f22f59a6..a16c39a9b30955d01c309822c3f29022194dc703 100644 --- a/src/protocol/parser/mod.rs +++ b/src/protocol/parser/mod.rs @@ -3,10 +3,12 @@ mod symbol_table; mod type_table; mod type_resolver; mod visitor; +mod visitor_linker; use depth_visitor::*; use symbol_table::SymbolTable; -use visitor::{Visitor2, ValidityAndLinkerVisitor}; +use visitor::Visitor2; +use visitor_linker::ValidityAndLinkerVisitor; use type_table::TypeTable; use crate::protocol::ast::*; diff --git a/src/protocol/parser/type_resolver.rs b/src/protocol/parser/type_resolver.rs index 40ecfb15e70169197073d67d854c0115fb799946..78f464e91809554092bfefc13e2a24bd77614ee5 100644 --- a/src/protocol/parser/type_resolver.rs +++ b/src/protocol/parser/type_resolver.rs @@ -1,4 +1,18 @@ -use super::visitor::{Visitor2, VisitorResult}; +use crate::protocol::ast::*; +use super::visitor::{ + STMT_BUFFER_INIT_CAPACITY, + EXPR_BUFFER_INIT_CAPACITY, + Ctx, + Visitor2, + VisitorResult +}; + +enum ExprType { + Regular, // expression statement or return statement + Memory, // memory statement's expression + Condition, // if/while conditional statement + Assert, // assert statement +} /// This particular visitor will recurse depth-first into the AST and ensures /// that all expressions have the appropriate types. At the moment this implies: @@ -10,9 +24,85 @@ use super::visitor::{Visitor2, VisitorResult}; /// This will be achieved by slowly descending into the AST. At any given /// expression we may depend on pub(crate) struct TypeResolvingVisitor { + // Tracking traversal state + expr_type: ExprType, + + // Buffers for iteration over substatements and subexpressions + stmt_buffer: Vec, + expr_buffer: Vec, +} + +impl TypeResolvingVisitor { + pub(crate) fn new() -> Self { + TypeResolvingVisitor{ + expr_type: ExprType::Regular, + stmt_buffer: Vec::with_capacity(STMT_BUFFER_INIT_CAPACITY), + expr_buffer: Vec::with_capacity(EXPR_BUFFER_INIT_CAPACITY), + } + } + fn reset(&mut self) { + self.expr_type = ExprType::Regular; + } } impl Visitor2 for TypeResolvingVisitor { + // Definitions + + fn visit_component_definition(&mut self, ctx: &mut Ctx, id: ComponentId) -> VisitorResult { + let body_stmt_id = ctx.heap[id].body; + self.visit_stmt(ctx, body_stmt_id) + } + + fn visit_function_definition(&mut self, ctx: &mut Ctx, id: FunctionId) -> VisitorResult { + let body_stmt_id = ctx.heap[id].body; + self.visit_stmt(ctx, body_stmt_id) + } + + // Statements + + fn visit_block_stmt(&mut self, ctx: &mut Ctx, id: BlockStatementId) -> VisitorResult { + // Transfer statements for traversal + let block = &ctx.heap[id]; + + let old_len_stmts = self.stmt_buffer.len(); + self.stmt_buffer.extend(&block.statements); + let new_len_stmts = self.stmt_buffer.len(); + + // Traverse statements + for stmt_idx in old_len_stmts..new_len_stmts { + let stmt_id = self.stmt_buffer[stmt_idx]; + self.expr_type = ExprType::Regular; + self.visit_stmt(ctx, stmt_id)?; + } + + self.stmt_buffer.truncate(old_len_stmts); + Ok(()) + } + + fn visit_local_memory_stmt(&mut self, ctx: &mut Ctx, id: MemoryStatementId) -> VisitorResult { + let memory_stmt = &ctx.heap[id]; + + // Type of local should match the type of the initial expression + + + // For now, with all variables having an explicit type, it seems we + // do not need to consider all expressions within a single definition in + // order to do typechecking and type inference for numeric constants. + + // Since each expression will get an assigned type, and everything + // already has a type, we may traverse leaf-to-root, while assigning + // output types. We throw an error if the types do not match. If an + // expression's type is already assigned then they should match. + + // Note that for numerical types the information may also travel upward. + // That is: if we have: + // + // u32 a = 5 * 2 << 3 + 8 + // + // Then we may infer that the expression yields a u32 type. As a result + // All of the literals 5, 2, 3 and 8 will have type u32 as well. + Ok(()) + } } \ No newline at end of file diff --git a/src/protocol/parser/visitor.rs b/src/protocol/parser/visitor.rs index 86aac5fc1b59a89d9000109abe1bcd0a0e5b68ca..02bfaf21b5242a657c1421a777248e2844de5e29 100644 --- a/src/protocol/parser/visitor.rs +++ b/src/protocol/parser/visitor.rs @@ -5,6 +5,14 @@ use crate::protocol::parser::{symbol_table::*, type_table::*, LexedModule}; type Unit = (); pub(crate) type VisitorResult = Result; +/// Globally configured vector capacity for statement buffers in visitor +/// implementations +pub(crate) const STMT_BUFFER_INIT_CAPACITY: usize = 256; +/// Globally configured vector capacity for expression buffers in visitor +/// implementations +pub(crate) const EXPR_BUFFER_INIT_CAPACITY: usize = 256; + +/// General context structure that is used while traversing the AST. pub(crate) struct Ctx<'p> { pub heap: &'p mut Heap, pub module: &'p LexedModule, @@ -227,1075 +235,4 @@ pub(crate) trait Visitor2 { fn visit_constant_expr(&mut self, _ctx: &mut Ctx, _id: ConstantExpressionId) -> VisitorResult { Ok(()) } fn visit_call_expr(&mut self, _ctx: &mut Ctx, _id: CallExpressionId) -> VisitorResult { Ok(()) } fn visit_variable_expr(&mut self, _ctx: &mut Ctx, _id: VariableExpressionId) -> VisitorResult { Ok(()) } -} - -#[derive(PartialEq, Eq)] -enum DefinitionType { - Primitive, - Composite, - Function -} - -/// This particular visitor will go through the entire AST in a recursive manner -/// and check if all statements and expressions are legal (e.g. no "return" -/// statements in component definitions), and will link certain AST nodes to -/// their appropriate targets (e.g. goto statements, or function calls). -/// -/// This visitor will not perform control-flow analysis (e.g. making sure that -/// each function actually returns) and will also not perform type checking. So -/// the linking of function calls and component instantiations will be checked -/// and linked to the appropriate definitions, but the return types and/or -/// arguments will not be checked for validity. -/// -/// The visitor visits each statement in a block in a breadth-first manner -/// first. We are thereby sure that we have found all variables/labels in a -/// particular block. In this phase nodes may queue statements for insertion -/// (e.g. the insertion of an `EndIf` statement for a particular `If` -/// statement). These will be inserted after visiting every node, after which -/// the visitor recurses into each statement in a block. -/// -/// Because of this scheme expressions will not be visited in the breadth-first -/// pass. -pub(crate) struct ValidityAndLinkerVisitor { - /// `in_sync` is `Some(id)` if the visitor is visiting the children of a - /// synchronous statement. A single value is sufficient as nested - /// synchronous statements are not allowed - in_sync: Option, - /// `in_while` contains the last encountered `While` statement. This is used - /// to resolve unlabeled `Continue`/`Break` statements. - in_while: Option, - // Traversal state: current scope (which can be used to find the parent - // scope), the definition variant we are considering, and whether the - // visitor is performing breadthwise block statement traversal. - cur_scope: Option, - def_type: DefinitionType, - performing_breadth_pass: bool, - // Keeping track of relative position in block in the breadth-first pass. - // May not correspond to block.statement[index] if any statements are - // inserted after the breadth-pass - relative_pos_in_block: u32, - // Single buffer of statement IDs that we want to traverse in a block. - // Required to work around Rust borrowing rules and to prevent constant - // cloning of vectors. - statement_buffer: Vec, - // Another buffer, now with expression IDs, to prevent constant cloning of - // vectors while working around rust's borrowing rules - expression_buffer: Vec, - // Statements to insert after the breadth pass in a single block - insert_buffer: Vec<(u32, StatementId)>, -} - -impl ValidityAndLinkerVisitor { - pub(crate) fn new() -> Self { - Self{ - in_sync: None, - in_while: None, - cur_scope: None, - def_type: DefinitionType::Primitive, - performing_breadth_pass: false, - relative_pos_in_block: 0, - statement_buffer: Vec::with_capacity(256), - expression_buffer: Vec::with_capacity(256), - insert_buffer: Vec::with_capacity(32), - } - } - - fn reset_state(&mut self) { - self.in_sync = None; - self.in_while = None; - self.cur_scope = None; - self.def_type = DefinitionType::Primitive; - self.relative_pos_in_block = 0; - self.performing_breadth_pass = false; - self.statement_buffer.clear(); - self.insert_buffer.clear(); - } -} - -impl Visitor2 for ValidityAndLinkerVisitor { - //-------------------------------------------------------------------------- - // Definition visitors - //-------------------------------------------------------------------------- - - fn visit_component_definition(&mut self, ctx: &mut Ctx, id: ComponentId) -> VisitorResult { - self.reset_state(); - - self.def_type = match &ctx.heap[id].variant { - ComponentVariant::Primitive => DefinitionType::Primitive, - ComponentVariant::Composite => DefinitionType::Composite, - }; - self.cur_scope = Some(Scope::Definition(id.upcast())); - let body_id = ctx.heap[id].body; - - self.performing_breadth_pass = true; - self.visit_stmt(ctx, body_id)?; - self.performing_breadth_pass = false; - self.visit_stmt(ctx, body_id) - } - - fn visit_function_definition(&mut self, ctx: &mut Ctx, id: FunctionId) -> VisitorResult { - self.reset_state(); - - // Set internal statement indices - self.def_type = DefinitionType::Function; - self.cur_scope = Some(Scope::Definition(id.upcast())); - let body_id = ctx.heap[id].body; - - self.performing_breadth_pass = true; - self.visit_stmt(ctx, body_id)?; - self.performing_breadth_pass = false; - self.visit_stmt(ctx, body_id) - } - - //-------------------------------------------------------------------------- - // Statement visitors - //-------------------------------------------------------------------------- - - fn visit_block_stmt(&mut self, ctx: &mut Ctx, id: BlockStatementId) -> VisitorResult { - self.visit_block_stmt_with_hint(ctx, id, None) - } - - fn visit_local_memory_stmt(&mut self, ctx: &mut Ctx, id: MemoryStatementId) -> VisitorResult { - if self.performing_breadth_pass { - let variable_id = ctx.heap[id].variable; - self.checked_local_add(ctx, self.relative_pos_in_block, variable_id)?; - } else { - self.visit_expr(ctx, ctx.heap[id].initial)?; - } - - Ok(()) - } - - fn visit_local_channel_stmt(&mut self, ctx: &mut Ctx, id: ChannelStatementId) -> VisitorResult { - if self.performing_breadth_pass { - let (from_id, to_id) = { - let stmt = &ctx.heap[id]; - (stmt.from, stmt.to) - }; - self.checked_local_add(ctx, self.relative_pos_in_block, from_id)?; - self.checked_local_add(ctx, self.relative_pos_in_block, to_id)?; - } - - Ok(()) - } - - fn visit_labeled_stmt(&mut self, ctx: &mut Ctx, id: LabeledStatementId) -> VisitorResult { - if self.performing_breadth_pass { - // Add label to block lookup - self.checked_label_add(ctx, id)?; - - // Modify labeled statement itself - let labeled = &mut ctx.heap[id]; - labeled.relative_pos_in_block = self.relative_pos_in_block; - labeled.in_sync = self.in_sync.clone(); - } - - let body_id = ctx.heap[id].body; - self.visit_stmt(ctx, body_id)?; - - Ok(()) - } - - fn visit_if_stmt(&mut self, ctx: &mut Ctx, id: IfStatementId) -> VisitorResult { - if self.performing_breadth_pass { - let position = ctx.heap[id].position; - let end_if_id = ctx.heap.alloc_end_if_statement(|this| { - EndIfStatement { - this, - start_if: id, - position, - next: None, - } - }); - let stmt = &mut ctx.heap[id]; - stmt.end_if = Some(end_if_id); - self.insert_buffer.push((self.relative_pos_in_block + 1, end_if_id.upcast())); - } else { - // Traverse expression and bodies - let (test_id, true_id, false_id) = { - let stmt = &ctx.heap[id]; - (stmt.test, stmt.true_body, stmt.false_body) - }; - self.visit_expr(ctx, test_id)?; - self.visit_stmt(ctx, true_id)?; - self.visit_stmt(ctx, false_id)?; - } - - Ok(()) - } - - fn visit_while_stmt(&mut self, ctx: &mut Ctx, id: WhileStatementId) -> VisitorResult { - if self.performing_breadth_pass { - let position = ctx.heap[id].position; - let end_while_id = ctx.heap.alloc_end_while_statement(|this| { - EndWhileStatement { - this, - start_while: id, - position, - next: None, - } - }); - let stmt = &mut ctx.heap[id]; - stmt.end_while = Some(end_while_id); - stmt.in_sync = self.in_sync.clone(); - - self.insert_buffer.push((self.relative_pos_in_block + 1, end_while_id.upcast())); - } else { - let (test_id, body_id) = { - let stmt = &ctx.heap[id]; - (stmt.test, stmt.body) - }; - let old_while = self.in_while.replace(id); - self.visit_expr(ctx, test_id)?; - self.visit_stmt(ctx, body_id)?; - self.in_while = old_while; - } - - Ok(()) - } - - fn visit_break_stmt(&mut self, ctx: &mut Ctx, id: BreakStatementId) -> VisitorResult { - if self.performing_breadth_pass { - // Should be able to resolve break statements with a label in the - // breadth pass, no need to do after resolving all labels - let target_end_while = { - let stmt = &ctx.heap[id]; - let target_while_id = self.resolve_break_or_continue_target(ctx, stmt.position, &stmt.label)?; - let target_while = &ctx.heap[target_while_id]; - debug_assert!(target_while.end_while.is_some()); - target_while.end_while.unwrap() - }; - - let stmt = &mut ctx.heap[id]; - stmt.target = Some(target_end_while); - } - - Ok(()) - } - - fn visit_continue_stmt(&mut self, ctx: &mut Ctx, id: ContinueStatementId) -> VisitorResult { - if self.performing_breadth_pass { - let target_while_id = { - let stmt = &ctx.heap[id]; - self.resolve_break_or_continue_target(ctx, stmt.position, &stmt.label)? - }; - - let stmt = &mut ctx.heap[id]; - stmt.target = Some(target_while_id) - } - - Ok(()) - } - - fn visit_synchronous_stmt(&mut self, ctx: &mut Ctx, id: SynchronousStatementId) -> VisitorResult { - if self.performing_breadth_pass { - // Check for validity of synchronous statement - let cur_sync_position = ctx.heap[id].position; - if self.in_sync.is_some() { - // Nested synchronous statement - let old_sync = &ctx.heap[self.in_sync.unwrap()]; - return Err( - ParseError2::new_error(&ctx.module.source, cur_sync_position, "Illegal nested synchronous statement") - .with_postfixed_info(&ctx.module.source, old_sync.position, "It is nested in this synchronous statement") - ); - } - - if self.def_type != DefinitionType::Primitive { - return Err(ParseError2::new_error( - &ctx.module.source, cur_sync_position, - "Synchronous statements may only be used in primitive components" - )); - } - - // Append SynchronousEnd pseudo-statement - let sync_end_id = ctx.heap.alloc_end_synchronous_statement(|this| EndSynchronousStatement{ - this, - position: cur_sync_position, - start_sync: id, - next: None, - }); - let sync_start = &mut ctx.heap[id]; - sync_start.end_sync = Some(sync_end_id); - self.insert_buffer.push((self.relative_pos_in_block + 1, sync_end_id.upcast())); - } else { - let sync_body = ctx.heap[id].body; - let old = self.in_sync.replace(id); - self.visit_stmt_with_hint(ctx, sync_body, Some(id))?; - self.in_sync = old; - } - - Ok(()) - } - - fn visit_return_stmt(&mut self, ctx: &mut Ctx, id: ReturnStatementId) -> VisitorResult { - if self.performing_breadth_pass { - let stmt = &ctx.heap[id]; - if self.def_type != DefinitionType::Function { - return Err( - ParseError2::new_error(&ctx.module.source, stmt.position, "Return statements may only appear in function bodies") - ); - } - } else { - // If here then we are within a function - self.visit_expr(ctx, ctx.heap[id].expression)?; - } - - Ok(()) - } - - fn visit_assert_stmt(&mut self, ctx: &mut Ctx, id: AssertStatementId) -> VisitorResult { - let stmt = &ctx.heap[id]; - if self.performing_breadth_pass { - if self.def_type == DefinitionType::Function { - // TODO: We probably want to allow this. Mark the function as - // using asserts, and then only allow calls to these functions - // within components. Such a marker will cascade through any - // functions that then call an asserting function - return Err( - ParseError2::new_error(&ctx.module.source, stmt.position, "Illegal assert statement in a function") - ); - } - - // We are in a component of some sort, but we also need to be within a - // synchronous statement - if self.in_sync.is_none() { - return Err( - ParseError2::new_error(&ctx.module.source, stmt.position, "Illegal assert statement outside of a synchronous block") - ); - } - } else { - let expr_id = stmt.expression; - self.visit_expr(ctx, expr_id)?; - } - - Ok(()) - } - - fn visit_goto_stmt(&mut self, ctx: &mut Ctx, id: GotoStatementId) -> VisitorResult { - if !self.performing_breadth_pass { - // Must perform goto label resolving after the breadth pass, this - // way we are able to find all the labels in current and outer - // scopes. - let target_id = self.find_label(ctx, &ctx.heap[id].label)?; - ctx.heap[id].target = Some(target_id); - - let target = &ctx.heap[target_id]; - if self.in_sync != target.in_sync { - // We can only goto the current scope or outer scopes. Because - // nested sync statements are not allowed so if the value does - // not match, then we must be inside a sync scope - debug_assert!(self.in_sync.is_some()); - let goto_stmt = &ctx.heap[id]; - let sync_stmt = &ctx.heap[self.in_sync.unwrap()]; - return Err( - ParseError2::new_error(&ctx.module.source, goto_stmt.position, "Goto may not escape the surrounding synchronous block") - .with_postfixed_info(&ctx.module.source, target.position, "This is the target of the goto statement") - .with_postfixed_info(&ctx.module.source, sync_stmt.position, "Which will jump past this statement") - ); - } - } - - Ok(()) - } - - fn visit_new_stmt(&mut self, ctx: &mut Ctx, id: NewStatementId) -> VisitorResult { - // Link the call expression following the new statement - if self.performing_breadth_pass { - // TODO: Cleanup error messages, can be done cleaner - // Make sure new statement occurs within a composite component - let call_expr_id = ctx.heap[id].expression; - if self.def_type != DefinitionType::Composite { - let new_stmt = &ctx.heap[id]; - return Err( - ParseError2::new_error(&ctx.module.source, new_stmt.position, "Instantiating components may only be done in composite components") - ); - } - - // No fancy recursive parsing, must be followed by a call expression - let definition_id = { - let call_expr = &ctx.heap[call_expr_id]; - if let Method::Symbolic(symbolic) = &call_expr.method { - let found_symbol = self.find_symbol_of_type( - ctx.module.root_id, &ctx.symbols, &ctx.types, - &symbolic.identifier, TypeClass::Component - ); - - match found_symbol { - FindOfTypeResult::Found(definition_id) => definition_id, - FindOfTypeResult::TypeMismatch(got_type_class) => { - return Err(ParseError2::new_error( - &ctx.module.source, symbolic.identifier.position, - &format!("New must instantiate a component, this identifier points to a {}", got_type_class) - )) - }, - FindOfTypeResult::NotFound => { - return Err(ParseError2::new_error( - &ctx.module.source, symbolic.identifier.position, - "Could not find a defined component with this name" - )) - } - } - } else { - return Err( - ParseError2::new_error(&ctx.module.source, call_expr.position, "Must instantiate a component") - ); - } - }; - - // Modify new statement's symbolic call to point to the appropriate - // definition. - let call_expr = &mut ctx.heap[call_expr_id]; - match &mut call_expr.method { - Method::Symbolic(method) => method.definition = Some(definition_id), - _ => unreachable!() - } - } else { - // Performing depth pass. The function definition should have been - // resolved in the breadth pass, now we recurse to evaluate the - // arguments - let call_expr_id = ctx.heap[id].expression; - let call_expr = &ctx.heap[call_expr_id]; - - let old_num_exprs = self.expression_buffer.len(); - self.expression_buffer.extend(&call_expr.arguments); - let new_num_exprs = self.expression_buffer.len(); - - for arg_expr_idx in old_num_exprs..new_num_exprs { - let arg_expr_id = self.expression_buffer[arg_expr_idx]; - self.visit_expr(ctx, arg_expr_id)?; - } - - self.expression_buffer.truncate(old_num_exprs); - } - - Ok(()) - } - - fn visit_put_stmt(&mut self, ctx: &mut Ctx, id: PutStatementId) -> VisitorResult { - // TODO: Make `put` an expression. Perhaps silly, but much easier to - // perform typechecking - if self.performing_breadth_pass { - let put_stmt = &ctx.heap[id]; - if self.in_sync.is_none() { - return Err(ParseError2::new_error( - &ctx.module.source, put_stmt.position, "Put must be called in a synchronous block" - )); - } - } else { - let put_stmt = &ctx.heap[id]; - let port = put_stmt.port; - let message = put_stmt.message; - self.visit_expr(ctx, port)?; - self.visit_expr(ctx, message)?; - } - - Ok(()) - } - - fn visit_expr_stmt(&mut self, ctx: &mut Ctx, id: ExpressionStatementId) -> VisitorResult { - if !self.performing_breadth_pass { - let expr_id = ctx.heap[id].expression; - self.visit_expr(ctx, expr_id)?; - } - - Ok(()) - } - - - //-------------------------------------------------------------------------- - // Expression visitors - //-------------------------------------------------------------------------- - - fn visit_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> VisitorResult { - debug_assert!(!self.performing_breadth_pass); - let assignment_expr = &ctx.heap[id]; - let left_expr_id = assignment_expr.left; - let right_expr_id = assignment_expr.right; - self.visit_expr(ctx, left_expr_id)?; - self.visit_expr(ctx, right_expr_id)?; - Ok(()) - } - - fn visit_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> VisitorResult { - debug_assert!(!self.performing_breadth_pass); - let conditional_expr = &ctx.heap[id]; - let test_expr_id = conditional_expr.test; - let true_expr_id = conditional_expr.true_expression; - let false_expr_id = conditional_expr.false_expression; - self.visit_expr(ctx, test_expr_id)?; - self.visit_expr(ctx, true_expr_id)?; - self.visit_expr(ctx, false_expr_id)?; - Ok(()) - } - - fn visit_binary_expr(&mut self, ctx: &mut Ctx, id: BinaryExpressionId) -> VisitorResult { - debug_assert!(!self.performing_breadth_pass); - let binary_expr = &ctx.heap[id]; - let left_expr_id = binary_expr.left; - let right_expr_id = binary_expr.right; - self.visit_expr(ctx, left_expr_id)?; - self.visit_expr(ctx, right_expr_id)?; - Ok(()) - } - - fn visit_unary_expr(&mut self, ctx: &mut Ctx, id: UnaryExpressionId) -> VisitorResult { - debug_assert!(!self.performing_breadth_pass); - let expr_id = ctx.heap[id].expression; - self.visit_expr(ctx, expr_id)?; - Ok(()) - } - - fn visit_indexing_expr(&mut self, ctx: &mut Ctx, id: IndexingExpressionId) -> VisitorResult { - debug_assert!(!self.performing_breadth_pass); - let indexing_expr = &ctx.heap[id]; - let subject_expr_id = indexing_expr.subject; - let index_expr_id = indexing_expr.index; - self.visit_expr(ctx, subject_expr_id)?; - self.visit_expr(ctx, index_expr_id)?; - Ok(()) - } - - fn visit_slicing_expr(&mut self, ctx: &mut Ctx, id: SlicingExpressionId) -> VisitorResult { - debug_assert!(!self.performing_breadth_pass); - // TODO: Same as the select expression: slicing depends on the type of - // the thing that is being sliced. - let slicing_expr = &ctx.heap[id]; - let subject_expr_id = slicing_expr.subject; - let from_expr_id = slicing_expr.from_index; - let to_expr_id = slicing_expr.to_index; - self.visit_expr(ctx, subject_expr_id)?; - self.visit_expr(ctx, from_expr_id)?; - self.visit_expr(ctx, to_expr_id)?; - Ok(()) - } - - fn visit_select_expr(&mut self, ctx: &mut Ctx, id: SelectExpressionId) -> VisitorResult { - debug_assert!(!self.performing_breadth_pass); - let expr_id = ctx.heap[id].subject; - self.visit_expr(ctx, expr_id)?; - - Ok(()) - } - - fn visit_array_expr(&mut self, ctx: &mut Ctx, id: ArrayExpressionId) -> VisitorResult { - debug_assert!(!self.performing_breadth_pass); - let array_expr = &ctx.heap[id]; - - let old_num_exprs = self.expression_buffer.len(); - self.expression_buffer.extend(&array_expr.elements); - let new_num_exprs = self.expression_buffer.len(); - - for field_expr_idx in old_num_exprs..new_num_exprs { - let field_expr_id = self.expression_buffer[field_expr_idx]; - self.visit_expr(ctx, field_expr_id)?; - } - - self.expression_buffer.truncate(old_num_exprs); - - Ok(()) - } - - fn visit_constant_expr(&mut self, _ctx: &mut Ctx, _id: ConstantExpressionId) -> VisitorResult { - debug_assert!(!self.performing_breadth_pass); - Ok(()) - } - - fn visit_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> VisitorResult { - debug_assert!(!self.performing_breadth_pass); - - let call_expr = &mut ctx.heap[id]; - - // Resolve the method to the appropriate definition and check the - // legality of the particular method call. - match &mut call_expr.method { - Method::Create => {}, - Method::Fires => { - if self.def_type != DefinitionType::Primitive { - return Err(ParseError2::new_error( - &ctx.module.source, call_expr.position, - "A call to 'fires' may only occur in primitive component definitions" - )); - } - }, - Method::Get => { - if self.def_type != DefinitionType::Primitive { - return Err(ParseError2::new_error( - &ctx.module.source, call_expr.position, - "A call to 'get' may only occur in primitive component definitions" - )); - } - }, - Method::Symbolic(symbolic) => { - // Find symbolic method - let found_symbol = self.find_symbol_of_type( - ctx.module.root_id, &ctx.symbols, &ctx.types, - &symbolic.identifier, TypeClass::Function - ); - let definition_id = match found_symbol { - FindOfTypeResult::Found(definition_id) => definition_id, - FindOfTypeResult::TypeMismatch(got_type_class) => { - return Err(ParseError2::new_error( - &ctx.module.source, symbolic.identifier.position, - &format!("Only functions can be called, this identifier points to a {}", got_type_class) - )) - }, - FindOfTypeResult::NotFound => { - return Err(ParseError2::new_error( - &ctx.module.source, symbolic.identifier.position, - &format!("Could not find a function with this name") - )) - } - }; - - symbolic.definition = Some(definition_id); - } - } - - // Parse all the arguments in the depth pass as well - let call_expr = &mut ctx.heap[id]; - let old_num_exprs = self.expression_buffer.len(); - self.expression_buffer.extend(&call_expr.arguments); - let new_num_exprs = self.expression_buffer.len(); - - for arg_expr_idx in old_num_exprs..new_num_exprs { - let arg_expr_id = self.expression_buffer[arg_expr_idx]; - self.visit_expr(ctx, arg_expr_id)?; - } - - self.expression_buffer.truncate(old_num_exprs); - - Ok(()) - } - - fn visit_variable_expr(&mut self, ctx: &mut Ctx, id: VariableExpressionId) -> VisitorResult { - debug_assert!(!self.performing_breadth_pass); - - let var_expr = &ctx.heap[id]; - let variable_id = self.find_variable(ctx, self.relative_pos_in_block, &var_expr.identifier)?; - let var_expr = &mut ctx.heap[id]; - var_expr.declaration = Some(variable_id); - - Ok(()) - } -} - -enum FindOfTypeResult { - // Identifier was exactly matched, type matched as well - Found(DefinitionId), - // Identifier was matched, but the type differs from the expected one - TypeMismatch(&'static str), - // Identifier could not be found - NotFound, -} - -impl ValidityAndLinkerVisitor { - //-------------------------------------------------------------------------- - // Special traversal - //-------------------------------------------------------------------------- - - /// Will visit a statement with a hint about its wrapping statement. This is - /// used to distinguish block statements with a wrapping synchronous - /// statement from normal block statements. - fn visit_stmt_with_hint(&mut self, ctx: &mut Ctx, id: StatementId, hint: Option) -> VisitorResult { - if let Statement::Block(block_stmt) = &ctx.heap[id] { - let block_id = block_stmt.this; - self.visit_block_stmt_with_hint(ctx, block_id, hint) - } else { - self.visit_stmt(ctx, id) - } - } - - fn visit_block_stmt_with_hint(&mut self, ctx: &mut Ctx, id: BlockStatementId, hint: Option) -> VisitorResult { - if self.performing_breadth_pass { - // Performing a breadth pass, so don't traverse into the statements - // of the block. - return Ok(()) - } - - // Set parent scope and relative position in the parent scope. Remember - // these values to set them back to the old values when we're done with - // the traversal of the block's statements. - let body = &mut ctx.heap[id]; - body.parent_scope = self.cur_scope.clone(); - println!("DEBUG: Assigning relative {} to block {}", self.relative_pos_in_block, id.0.0.index); - body.relative_pos_in_parent = self.relative_pos_in_block; - - let old_scope = self.cur_scope.replace(match hint { - Some(sync_id) => Scope::Synchronous((sync_id, id)), - None => Scope::Regular(id), - }); - let old_relative_pos = self.relative_pos_in_block; - - // Copy statement IDs into buffer - let old_num_stmts = self.statement_buffer.len(); - { - let body = &ctx.heap[id]; - self.statement_buffer.extend_from_slice(&body.statements); - } - let new_num_stmts = self.statement_buffer.len(); - - // Perform the breadth-first pass. Its main purpose is to find labeled - // statements such that we can find the `goto`-targets immediately when - // performing the depth pass - self.performing_breadth_pass = true; - for stmt_idx in old_num_stmts..new_num_stmts { - self.relative_pos_in_block = (stmt_idx - old_num_stmts) as u32; - self.visit_stmt(ctx, self.statement_buffer[stmt_idx])?; - } - - if !self.insert_buffer.is_empty() { - let body = &mut ctx.heap[id]; - for (insert_idx, (pos, stmt)) in self.insert_buffer.drain(..).enumerate() { - body.statements.insert(pos as usize + insert_idx, stmt); - } - } - - // And the depth pass. Because we're not actually visiting any inserted - // nodes because we're using the statement buffer, we may safely use the - // relative_pos_in_block counter. - self.performing_breadth_pass = false; - for stmt_idx in old_num_stmts..new_num_stmts { - self.relative_pos_in_block = (stmt_idx - old_num_stmts) as u32; - self.visit_stmt(ctx, self.statement_buffer[stmt_idx])?; - } - - self.cur_scope = old_scope; - self.relative_pos_in_block = old_relative_pos; - - // Pop statement buffer - debug_assert!(self.insert_buffer.is_empty(), "insert buffer not empty after depth pass"); - self.statement_buffer.truncate(old_num_stmts); - - Ok(()) - } - - //-------------------------------------------------------------------------- - // Utilities - //-------------------------------------------------------------------------- - - /// Adds a local variable to the current scope. It will also annotate the - /// `Local` in the AST with its relative position in the block. - fn checked_local_add(&mut self, ctx: &mut Ctx, relative_pos: u32, id: LocalId) -> Result<(), ParseError2> { - debug_assert!(self.cur_scope.is_some()); - - // Make sure we do not conflict with any global symbols - { - let ident = &ctx.heap[id].identifier; - if let Some(symbol) = ctx.symbols.resolve_symbol(ctx.module.root_id, &ident.value) { - return Err( - ParseError2::new_error(&ctx.module.source, ident.position, "Local variable declaration conflicts with symbol") - .with_postfixed_info(&ctx.module.source, symbol.position, "Conflicting symbol is found here") - ); - } - } - - let local = &mut ctx.heap[id]; - local.relative_pos_in_block = relative_pos; - - // Make sure we do not shadow any variables in any of the scopes. Note - // that variables in parent scopes may be declared later - let local = &ctx.heap[id]; - let mut scope = self.cur_scope.as_ref().unwrap(); - let mut local_relative_pos = self.relative_pos_in_block; - - loop { - debug_assert!(scope.is_block(), "scope is not a block"); - let block = &ctx.heap[scope.to_block()]; - for other_local_id in &block.locals { - let other_local = &ctx.heap[*other_local_id]; - // Position check in case another variable with the same name - // is defined in a higher-level scope, but later than the scope - // in which the current variable resides. - if local.this != *other_local_id && - local_relative_pos >= other_local.relative_pos_in_block && - local.identifier.value == other_local.identifier.value { - // Collision within this scope - return Err( - ParseError2::new_error(&ctx.module.source, local.position, "Local variable name conflicts with another variable") - .with_postfixed_info(&ctx.module.source, other_local.position, "Previous variable is found here") - ); - } - } - - // Current scope is fine, move to parent scope if any - debug_assert!(block.parent_scope.is_some(), "block scope does not have a parent"); - scope = block.parent_scope.as_ref().unwrap(); - if let Scope::Definition(definition_id) = scope { - // At outer scope, check parameters of function/component - for parameter_id in ctx.heap[*definition_id].parameters() { - let parameter = &ctx.heap[*parameter_id]; - if local.identifier.value == parameter.identifier.value { - return Err( - ParseError2::new_error(&ctx.module.source, local.position, "Local variable name conflicts with parameter") - .with_postfixed_info(&ctx.module.source, parameter.position, "Parameter definition is found here") - ); - } - } - - break; - } - - // If here, then we are dealing with a block-like parent block - local_relative_pos = ctx.heap[scope.to_block()].relative_pos_in_parent; - } - - // No collisions at all - let block = &mut ctx.heap[self.cur_scope.as_ref().unwrap().to_block()]; - block.locals.push(id); - - Ok(()) - } - - /// Finds a variable in the visitor's scope that must appear before the - /// specified relative position within that block. - fn find_variable(&self, ctx: &Ctx, mut relative_pos: u32, identifier: &NamespacedIdentifier) -> Result { - debug_assert!(self.cur_scope.is_some()); - debug_assert!(identifier.num_namespaces > 0); - - // TODO: Update once globals are possible as well - if identifier.num_namespaces > 1 { - todo!("Implement namespaced constant seeking") - } - - // TODO: May still refer to an alias of a global symbol using a single - // identifier in the namespace. - // No need to use iterator over namespaces if here - let mut scope = self.cur_scope.as_ref().unwrap(); - println!(" *** DEBUG: Looking for {}, depth = {}", String::from_utf8_lossy(&identifier.value), self.performing_breadth_pass); - loop { - debug_assert!(scope.is_block()); - let block = &ctx.heap[scope.to_block()]; - println!("DEBUG: Looking in block {} with relative pos {}", block.this.0.0.index, relative_pos); - for local_id in &block.locals { - let local = &ctx.heap[*local_id]; - println!("DEBUG: Comparing against '{}' with relative pos {}", - String::from_utf8_lossy(&local.identifier.value), local.relative_pos_in_block); - if local.relative_pos_in_block < relative_pos && local.identifier.value == identifier.value { - return Ok(local_id.upcast()); - } - } - - debug_assert!(block.parent_scope.is_some()); - scope = block.parent_scope.as_ref().unwrap(); - if !scope.is_block() { - // Definition scope, need to check arguments to definition - println!("DEBUG: Looking in definition scope now..."); - match scope { - Scope::Definition(definition_id) => { - let definition = &ctx.heap[*definition_id]; - for parameter_id in definition.parameters() { - let parameter = &ctx.heap[*parameter_id]; - if parameter.identifier.value == identifier.value { - return Ok(parameter_id.upcast()); - } - } - }, - _ => unreachable!(), - } - - // Variable could not be found - return Err(ParseError2::new_error( - &ctx.module.source, identifier.position, "This variable is not declared" - )); - } else { - relative_pos = block.relative_pos_in_parent; - } - } - } - - /// Adds a particular label to the current scope. Will return an error if - /// there is another label with the same name visible in the current scope. - fn checked_label_add(&mut self, ctx: &mut Ctx, id: LabeledStatementId) -> Result<(), ParseError2> { - debug_assert!(self.cur_scope.is_some()); - - // Make sure label is not defined within the current scope or any of the - // parent scope. - let label = &ctx.heap[id]; - let mut scope = self.cur_scope.as_ref().unwrap(); - - loop { - debug_assert!(scope.is_block(), "scope is not a block"); - let block = &ctx.heap[scope.to_block()]; - for other_label_id in &block.labels { - let other_label = &ctx.heap[*other_label_id]; - if other_label.label.value == label.label.value { - // Collision - return Err( - ParseError2::new_error(&ctx.module.source, label.position, "Label name conflicts with another label") - .with_postfixed_info(&ctx.module.source, other_label.position, "Other label is found here") - ); - } - } - - debug_assert!(block.parent_scope.is_some(), "block scope does not have a parent"); - scope = block.parent_scope.as_ref().unwrap(); - if !scope.is_block() { - break; - } - } - - // No collisions - let block = &mut ctx.heap[self.cur_scope.as_ref().unwrap().to_block()]; - block.labels.push(id); - - Ok(()) - } - - /// Finds a particular labeled statement by its identifier. Once found it - /// will make sure that the target label does not skip over any variable - /// declarations within the scope in which the label was found. - fn find_label(&self, ctx: &Ctx, identifier: &Identifier) -> Result { - debug_assert!(self.cur_scope.is_some()); - - let mut scope = self.cur_scope.as_ref().unwrap(); - loop { - debug_assert!(scope.is_block(), "scope is not a block"); - let relative_scope_pos = ctx.heap[scope.to_block()].relative_pos_in_parent; - - let block = &ctx.heap[scope.to_block()]; - for label_id in &block.labels { - let label = &ctx.heap[*label_id]; - if label.label.value == identifier.value { - for local_id in &block.locals { - // TODO: Better to do this in control flow analysis, it - // is legal to skip over a variable declaration if it - // is not actually being used. I might be missing - // something here when laying out the bytecode... - let local = &ctx.heap[*local_id]; - if local.relative_pos_in_block > relative_scope_pos && local.relative_pos_in_block < label.relative_pos_in_block { - return Err( - ParseError2::new_error(&ctx.module.source, identifier.position, "This target label skips over a variable declaration") - .with_postfixed_info(&ctx.module.source, label.position, "Because it jumps to this label") - .with_postfixed_info(&ctx.module.source, local.position, "Which skips over this variable") - ); - } - } - return Ok(*label_id); - } - } - - debug_assert!(block.parent_scope.is_some(), "block scope does not have a parent"); - scope = block.parent_scope.as_ref().unwrap(); - if !scope.is_block() { - return Err(ParseError2::new_error(&ctx.module.source, identifier.position, "Could not find this label")); - } - - } - } - - /// Finds a particular symbol in the symbol table which must correspond to - /// a definition of a particular type. - // Note: root_id, symbols and types passed in explicitly to prevent - // borrowing errors - fn find_symbol_of_type( - &self, root_id: RootId, symbols: &SymbolTable, types: &TypeTable, - identifier: &NamespacedIdentifier, expected_type_class: TypeClass - ) -> FindOfTypeResult { - // Find symbol associated with identifier - let symbol = symbols.resolve_namespaced_symbol(root_id, &identifier); - if symbol.is_none() { return FindOfTypeResult::NotFound; } - - let (symbol, iter) = symbol.unwrap(); - if iter.num_remaining() != 0 { return FindOfTypeResult::NotFound; } - - match &symbol.symbol { - Symbol::Definition((_, definition_id)) => { - // Make sure definition is of the expected type - let definition_type = types.get_definition(definition_id); - debug_assert!(definition_type.is_some(), "Found symbol '{}' in symbol table, but not in type table", String::from_utf8_lossy(&identifier.value)); - let definition_type_class = definition_type.unwrap().type_class(); - - if definition_type_class != expected_type_class { - FindOfTypeResult::TypeMismatch(definition_type_class.display_name()) - } else { - FindOfTypeResult::Found(*definition_id) - } - }, - Symbol::Namespace(_) => FindOfTypeResult::TypeMismatch("namespace"), - } - } - - /// This function will check if the provided while statement ID has a block - /// statement that is one of our current parents. - fn has_parent_while_scope(&self, ctx: &Ctx, id: WhileStatementId) -> bool { - debug_assert!(self.cur_scope.is_some()); - let mut scope = self.cur_scope.as_ref().unwrap(); - let while_stmt = &ctx.heap[id]; - loop { - debug_assert!(scope.is_block()); - let block = scope.to_block(); - if while_stmt.body == block.upcast() { - return true; - } - - let block = &ctx.heap[block]; - debug_assert!(block.parent_scope.is_some(), "block scope does not have a parent"); - scope = block.parent_scope.as_ref().unwrap(); - if !scope.is_block() { - return false; - } - } - } - - /// This function should be called while dealing with break/continue - /// statements. It will try to find the targeted while statement, using the - /// target label if provided. If a valid target is found then the loop's - /// ID will be returned, otherwise a parsing error is constructed. - /// The provided input position should be the position of the break/continue - /// statement. - fn resolve_break_or_continue_target(&self, ctx: &Ctx, position: InputPosition, label: &Option) -> Result { - let target = match label { - Some(label) => { - let target_id = self.find_label(ctx, label)?; - - // Make sure break target is a while statement - let target = &ctx.heap[target_id]; - if let Statement::While(target_stmt) = &ctx.heap[target.body] { - // Even though we have a target while statement, the break might not be - // present underneath this particular labeled while statement - if !self.has_parent_while_scope(ctx, target_stmt.this) { - ParseError2::new_error(&ctx.module.source, label.position, "Break statement is not nested under the target label's while statement") - .with_postfixed_info(&ctx.module.source, target.position, "The targeted label is found here"); - } - - target_stmt.this - } else { - return Err( - ParseError2::new_error(&ctx.module.source, label.position, "Incorrect break target label, it must target a while loop") - .with_postfixed_info(&ctx.module.source, target.position, "The targeted label is found here") - ); - } - }, - None => { - // Use the enclosing while statement, the break must be - // nested within that while statement - if self.in_while.is_none() { - return Err( - ParseError2::new_error(&ctx.module.source, position, "Break statement is not nested under a while loop") - ); - } - - self.in_while.unwrap() - } - }; - - // We have a valid target for the break statement. But we need to - // make sure we will not break out of a synchronous block - { - let target_while = &ctx.heap[target]; - if target_while.in_sync != self.in_sync { - // Break is nested under while statement, so can only escape a - // sync block if the sync is nested inside the while statement. - debug_assert!(self.in_sync.is_some()); - let sync_stmt = &ctx.heap[self.in_sync.unwrap()]; - return Err( - ParseError2::new_error(&ctx.module.source, position, "Break may not escape the surrounding synchronous block") - .with_postfixed_info(&ctx.module.source, target_while.position, "The break escapes out of this loop") - .with_postfixed_info(&ctx.module.source, sync_stmt.position, "And would therefore escape this synchronous block") - ); - } - } - - Ok(target) - } } \ No newline at end of file diff --git a/src/protocol/parser/visitor_linker.rs b/src/protocol/parser/visitor_linker.rs new file mode 100644 index 0000000000000000000000000000000000000000..8a05332e0e5a85283b79e376d602260dcb788b86 --- /dev/null +++ b/src/protocol/parser/visitor_linker.rs @@ -0,0 +1,1080 @@ +use crate::protocol::ast::*; +use crate::protocol::inputsource::*; +use crate::protocol::parser::{symbol_table::*, type_table::*}; + +use super::visitor::{ + STMT_BUFFER_INIT_CAPACITY, + EXPR_BUFFER_INIT_CAPACITY, + Ctx, + Visitor2, + VisitorResult +}; + +#[derive(PartialEq, Eq)] +enum DefinitionType { + Primitive, + Composite, + Function +} + +/// This particular visitor will go through the entire AST in a recursive manner +/// and check if all statements and expressions are legal (e.g. no "return" +/// statements in component definitions), and will link certain AST nodes to +/// their appropriate targets (e.g. goto statements, or function calls). +/// +/// This visitor will not perform control-flow analysis (e.g. making sure that +/// each function actually returns) and will also not perform type checking. So +/// the linking of function calls and component instantiations will be checked +/// and linked to the appropriate definitions, but the return types and/or +/// arguments will not be checked for validity. +/// +/// The visitor visits each statement in a block in a breadth-first manner +/// first. We are thereby sure that we have found all variables/labels in a +/// particular block. In this phase nodes may queue statements for insertion +/// (e.g. the insertion of an `EndIf` statement for a particular `If` +/// statement). These will be inserted after visiting every node, after which +/// the visitor recurses into each statement in a block. +/// +/// Because of this scheme expressions will not be visited in the breadth-first +/// pass. +pub(crate) struct ValidityAndLinkerVisitor { + /// `in_sync` is `Some(id)` if the visitor is visiting the children of a + /// synchronous statement. A single value is sufficient as nested + /// synchronous statements are not allowed + in_sync: Option, + /// `in_while` contains the last encountered `While` statement. This is used + /// to resolve unlabeled `Continue`/`Break` statements. + in_while: Option, + // Traversal state: current scope (which can be used to find the parent + // scope), the definition variant we are considering, and whether the + // visitor is performing breadthwise block statement traversal. + cur_scope: Option, + def_type: DefinitionType, + performing_breadth_pass: bool, + // Keeping track of relative position in block in the breadth-first pass. + // May not correspond to block.statement[index] if any statements are + // inserted after the breadth-pass + relative_pos_in_block: u32, + // Single buffer of statement IDs that we want to traverse in a block. + // Required to work around Rust borrowing rules and to prevent constant + // cloning of vectors. + statement_buffer: Vec, + // Another buffer, now with expression IDs, to prevent constant cloning of + // vectors while working around rust's borrowing rules + expression_buffer: Vec, + // Statements to insert after the breadth pass in a single block + insert_buffer: Vec<(u32, StatementId)>, +} + +impl ValidityAndLinkerVisitor { + pub(crate) fn new() -> Self { + Self{ + in_sync: None, + in_while: None, + cur_scope: None, + def_type: DefinitionType::Primitive, + performing_breadth_pass: false, + relative_pos_in_block: 0, + statement_buffer: Vec::with_capacity(STMT_BUFFER_INIT_CAPACITY), + expression_buffer: Vec::with_capacity(EXPR_BUFFER_INIT_CAPACITY), + insert_buffer: Vec::with_capacity(32), + } + } + + fn reset_state(&mut self) { + self.in_sync = None; + self.in_while = None; + self.cur_scope = None; + self.def_type = DefinitionType::Primitive; + self.relative_pos_in_block = 0; + self.performing_breadth_pass = false; + self.statement_buffer.clear(); + self.insert_buffer.clear(); + } +} + +impl Visitor2 for ValidityAndLinkerVisitor { + //-------------------------------------------------------------------------- + // Definition visitors + //-------------------------------------------------------------------------- + + fn visit_component_definition(&mut self, ctx: &mut Ctx, id: ComponentId) -> VisitorResult { + self.reset_state(); + + self.def_type = match &ctx.heap[id].variant { + ComponentVariant::Primitive => DefinitionType::Primitive, + ComponentVariant::Composite => DefinitionType::Composite, + }; + self.cur_scope = Some(Scope::Definition(id.upcast())); + let body_id = ctx.heap[id].body; + + self.performing_breadth_pass = true; + self.visit_stmt(ctx, body_id)?; + self.performing_breadth_pass = false; + self.visit_stmt(ctx, body_id) + } + + fn visit_function_definition(&mut self, ctx: &mut Ctx, id: FunctionId) -> VisitorResult { + self.reset_state(); + + // Set internal statement indices + self.def_type = DefinitionType::Function; + self.cur_scope = Some(Scope::Definition(id.upcast())); + let body_id = ctx.heap[id].body; + + self.performing_breadth_pass = true; + self.visit_stmt(ctx, body_id)?; + self.performing_breadth_pass = false; + self.visit_stmt(ctx, body_id) + } + + //-------------------------------------------------------------------------- + // Statement visitors + //-------------------------------------------------------------------------- + + fn visit_block_stmt(&mut self, ctx: &mut Ctx, id: BlockStatementId) -> VisitorResult { + self.visit_block_stmt_with_hint(ctx, id, None) + } + + fn visit_local_memory_stmt(&mut self, ctx: &mut Ctx, id: MemoryStatementId) -> VisitorResult { + if self.performing_breadth_pass { + let variable_id = ctx.heap[id].variable; + self.checked_local_add(ctx, self.relative_pos_in_block, variable_id)?; + } else { + self.visit_expr(ctx, ctx.heap[id].initial)?; + } + + Ok(()) + } + + fn visit_local_channel_stmt(&mut self, ctx: &mut Ctx, id: ChannelStatementId) -> VisitorResult { + if self.performing_breadth_pass { + let (from_id, to_id) = { + let stmt = &ctx.heap[id]; + (stmt.from, stmt.to) + }; + self.checked_local_add(ctx, self.relative_pos_in_block, from_id)?; + self.checked_local_add(ctx, self.relative_pos_in_block, to_id)?; + } + + Ok(()) + } + + fn visit_labeled_stmt(&mut self, ctx: &mut Ctx, id: LabeledStatementId) -> VisitorResult { + if self.performing_breadth_pass { + // Add label to block lookup + self.checked_label_add(ctx, id)?; + + // Modify labeled statement itself + let labeled = &mut ctx.heap[id]; + labeled.relative_pos_in_block = self.relative_pos_in_block; + labeled.in_sync = self.in_sync.clone(); + } + + let body_id = ctx.heap[id].body; + self.visit_stmt(ctx, body_id)?; + + Ok(()) + } + + fn visit_if_stmt(&mut self, ctx: &mut Ctx, id: IfStatementId) -> VisitorResult { + if self.performing_breadth_pass { + let position = ctx.heap[id].position; + let end_if_id = ctx.heap.alloc_end_if_statement(|this| { + EndIfStatement { + this, + start_if: id, + position, + next: None, + } + }); + let stmt = &mut ctx.heap[id]; + stmt.end_if = Some(end_if_id); + self.insert_buffer.push((self.relative_pos_in_block + 1, end_if_id.upcast())); + } else { + // Traverse expression and bodies + let (test_id, true_id, false_id) = { + let stmt = &ctx.heap[id]; + (stmt.test, stmt.true_body, stmt.false_body) + }; + self.visit_expr(ctx, test_id)?; + self.visit_stmt(ctx, true_id)?; + self.visit_stmt(ctx, false_id)?; + } + + Ok(()) + } + + fn visit_while_stmt(&mut self, ctx: &mut Ctx, id: WhileStatementId) -> VisitorResult { + if self.performing_breadth_pass { + let position = ctx.heap[id].position; + let end_while_id = ctx.heap.alloc_end_while_statement(|this| { + EndWhileStatement { + this, + start_while: id, + position, + next: None, + } + }); + let stmt = &mut ctx.heap[id]; + stmt.end_while = Some(end_while_id); + stmt.in_sync = self.in_sync.clone(); + + self.insert_buffer.push((self.relative_pos_in_block + 1, end_while_id.upcast())); + } else { + let (test_id, body_id) = { + let stmt = &ctx.heap[id]; + (stmt.test, stmt.body) + }; + let old_while = self.in_while.replace(id); + self.visit_expr(ctx, test_id)?; + self.visit_stmt(ctx, body_id)?; + self.in_while = old_while; + } + + Ok(()) + } + + fn visit_break_stmt(&mut self, ctx: &mut Ctx, id: BreakStatementId) -> VisitorResult { + if self.performing_breadth_pass { + // Should be able to resolve break statements with a label in the + // breadth pass, no need to do after resolving all labels + let target_end_while = { + let stmt = &ctx.heap[id]; + let target_while_id = self.resolve_break_or_continue_target(ctx, stmt.position, &stmt.label)?; + let target_while = &ctx.heap[target_while_id]; + debug_assert!(target_while.end_while.is_some()); + target_while.end_while.unwrap() + }; + + let stmt = &mut ctx.heap[id]; + stmt.target = Some(target_end_while); + } + + Ok(()) + } + + fn visit_continue_stmt(&mut self, ctx: &mut Ctx, id: ContinueStatementId) -> VisitorResult { + if self.performing_breadth_pass { + let target_while_id = { + let stmt = &ctx.heap[id]; + self.resolve_break_or_continue_target(ctx, stmt.position, &stmt.label)? + }; + + let stmt = &mut ctx.heap[id]; + stmt.target = Some(target_while_id) + } + + Ok(()) + } + + fn visit_synchronous_stmt(&mut self, ctx: &mut Ctx, id: SynchronousStatementId) -> VisitorResult { + if self.performing_breadth_pass { + // Check for validity of synchronous statement + let cur_sync_position = ctx.heap[id].position; + if self.in_sync.is_some() { + // Nested synchronous statement + let old_sync = &ctx.heap[self.in_sync.unwrap()]; + return Err( + ParseError2::new_error(&ctx.module.source, cur_sync_position, "Illegal nested synchronous statement") + .with_postfixed_info(&ctx.module.source, old_sync.position, "It is nested in this synchronous statement") + ); + } + + if self.def_type != DefinitionType::Primitive { + return Err(ParseError2::new_error( + &ctx.module.source, cur_sync_position, + "Synchronous statements may only be used in primitive components" + )); + } + + // Append SynchronousEnd pseudo-statement + let sync_end_id = ctx.heap.alloc_end_synchronous_statement(|this| EndSynchronousStatement{ + this, + position: cur_sync_position, + start_sync: id, + next: None, + }); + let sync_start = &mut ctx.heap[id]; + sync_start.end_sync = Some(sync_end_id); + self.insert_buffer.push((self.relative_pos_in_block + 1, sync_end_id.upcast())); + } else { + let sync_body = ctx.heap[id].body; + let old = self.in_sync.replace(id); + self.visit_stmt_with_hint(ctx, sync_body, Some(id))?; + self.in_sync = old; + } + + Ok(()) + } + + fn visit_return_stmt(&mut self, ctx: &mut Ctx, id: ReturnStatementId) -> VisitorResult { + if self.performing_breadth_pass { + let stmt = &ctx.heap[id]; + if self.def_type != DefinitionType::Function { + return Err( + ParseError2::new_error(&ctx.module.source, stmt.position, "Return statements may only appear in function bodies") + ); + } + } else { + // If here then we are within a function + self.visit_expr(ctx, ctx.heap[id].expression)?; + } + + Ok(()) + } + + fn visit_assert_stmt(&mut self, ctx: &mut Ctx, id: AssertStatementId) -> VisitorResult { + let stmt = &ctx.heap[id]; + if self.performing_breadth_pass { + if self.def_type == DefinitionType::Function { + // TODO: We probably want to allow this. Mark the function as + // using asserts, and then only allow calls to these functions + // within components. Such a marker will cascade through any + // functions that then call an asserting function + return Err( + ParseError2::new_error(&ctx.module.source, stmt.position, "Illegal assert statement in a function") + ); + } + + // We are in a component of some sort, but we also need to be within a + // synchronous statement + if self.in_sync.is_none() { + return Err( + ParseError2::new_error(&ctx.module.source, stmt.position, "Illegal assert statement outside of a synchronous block") + ); + } + } else { + let expr_id = stmt.expression; + self.visit_expr(ctx, expr_id)?; + } + + Ok(()) + } + + fn visit_goto_stmt(&mut self, ctx: &mut Ctx, id: GotoStatementId) -> VisitorResult { + if !self.performing_breadth_pass { + // Must perform goto label resolving after the breadth pass, this + // way we are able to find all the labels in current and outer + // scopes. + let target_id = self.find_label(ctx, &ctx.heap[id].label)?; + ctx.heap[id].target = Some(target_id); + + let target = &ctx.heap[target_id]; + if self.in_sync != target.in_sync { + // We can only goto the current scope or outer scopes. Because + // nested sync statements are not allowed so if the value does + // not match, then we must be inside a sync scope + debug_assert!(self.in_sync.is_some()); + let goto_stmt = &ctx.heap[id]; + let sync_stmt = &ctx.heap[self.in_sync.unwrap()]; + return Err( + ParseError2::new_error(&ctx.module.source, goto_stmt.position, "Goto may not escape the surrounding synchronous block") + .with_postfixed_info(&ctx.module.source, target.position, "This is the target of the goto statement") + .with_postfixed_info(&ctx.module.source, sync_stmt.position, "Which will jump past this statement") + ); + } + } + + Ok(()) + } + + fn visit_new_stmt(&mut self, ctx: &mut Ctx, id: NewStatementId) -> VisitorResult { + // Link the call expression following the new statement + if self.performing_breadth_pass { + // TODO: Cleanup error messages, can be done cleaner + // Make sure new statement occurs within a composite component + let call_expr_id = ctx.heap[id].expression; + if self.def_type != DefinitionType::Composite { + let new_stmt = &ctx.heap[id]; + return Err( + ParseError2::new_error(&ctx.module.source, new_stmt.position, "Instantiating components may only be done in composite components") + ); + } + + // No fancy recursive parsing, must be followed by a call expression + let definition_id = { + let call_expr = &ctx.heap[call_expr_id]; + if let Method::Symbolic(symbolic) = &call_expr.method { + let found_symbol = self.find_symbol_of_type( + ctx.module.root_id, &ctx.symbols, &ctx.types, + &symbolic.identifier, TypeClass::Component + ); + + match found_symbol { + FindOfTypeResult::Found(definition_id) => definition_id, + FindOfTypeResult::TypeMismatch(got_type_class) => { + return Err(ParseError2::new_error( + &ctx.module.source, symbolic.identifier.position, + &format!("New must instantiate a component, this identifier points to a {}", got_type_class) + )) + }, + FindOfTypeResult::NotFound => { + return Err(ParseError2::new_error( + &ctx.module.source, symbolic.identifier.position, + "Could not find a defined component with this name" + )) + } + } + } else { + return Err( + ParseError2::new_error(&ctx.module.source, call_expr.position, "Must instantiate a component") + ); + } + }; + + // Modify new statement's symbolic call to point to the appropriate + // definition. + let call_expr = &mut ctx.heap[call_expr_id]; + match &mut call_expr.method { + Method::Symbolic(method) => method.definition = Some(definition_id), + _ => unreachable!() + } + } else { + // Performing depth pass. The function definition should have been + // resolved in the breadth pass, now we recurse to evaluate the + // arguments + let call_expr_id = ctx.heap[id].expression; + let call_expr = &ctx.heap[call_expr_id]; + + let old_num_exprs = self.expression_buffer.len(); + self.expression_buffer.extend(&call_expr.arguments); + let new_num_exprs = self.expression_buffer.len(); + + for arg_expr_idx in old_num_exprs..new_num_exprs { + let arg_expr_id = self.expression_buffer[arg_expr_idx]; + self.visit_expr(ctx, arg_expr_id)?; + } + + self.expression_buffer.truncate(old_num_exprs); + } + + Ok(()) + } + + fn visit_put_stmt(&mut self, ctx: &mut Ctx, id: PutStatementId) -> VisitorResult { + // TODO: Make `put` an expression. Perhaps silly, but much easier to + // perform typechecking + if self.performing_breadth_pass { + let put_stmt = &ctx.heap[id]; + if self.in_sync.is_none() { + return Err(ParseError2::new_error( + &ctx.module.source, put_stmt.position, "Put must be called in a synchronous block" + )); + } + } else { + let put_stmt = &ctx.heap[id]; + let port = put_stmt.port; + let message = put_stmt.message; + self.visit_expr(ctx, port)?; + self.visit_expr(ctx, message)?; + } + + Ok(()) + } + + fn visit_expr_stmt(&mut self, ctx: &mut Ctx, id: ExpressionStatementId) -> VisitorResult { + if !self.performing_breadth_pass { + let expr_id = ctx.heap[id].expression; + self.visit_expr(ctx, expr_id)?; + } + + Ok(()) + } + + + //-------------------------------------------------------------------------- + // Expression visitors + //-------------------------------------------------------------------------- + + fn visit_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> VisitorResult { + debug_assert!(!self.performing_breadth_pass); + let assignment_expr = &ctx.heap[id]; + let left_expr_id = assignment_expr.left; + let right_expr_id = assignment_expr.right; + self.visit_expr(ctx, left_expr_id)?; + self.visit_expr(ctx, right_expr_id)?; + Ok(()) + } + + fn visit_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> VisitorResult { + debug_assert!(!self.performing_breadth_pass); + let conditional_expr = &ctx.heap[id]; + let test_expr_id = conditional_expr.test; + let true_expr_id = conditional_expr.true_expression; + let false_expr_id = conditional_expr.false_expression; + self.visit_expr(ctx, test_expr_id)?; + self.visit_expr(ctx, true_expr_id)?; + self.visit_expr(ctx, false_expr_id)?; + Ok(()) + } + + fn visit_binary_expr(&mut self, ctx: &mut Ctx, id: BinaryExpressionId) -> VisitorResult { + debug_assert!(!self.performing_breadth_pass); + let binary_expr = &ctx.heap[id]; + let left_expr_id = binary_expr.left; + let right_expr_id = binary_expr.right; + self.visit_expr(ctx, left_expr_id)?; + self.visit_expr(ctx, right_expr_id)?; + Ok(()) + } + + fn visit_unary_expr(&mut self, ctx: &mut Ctx, id: UnaryExpressionId) -> VisitorResult { + debug_assert!(!self.performing_breadth_pass); + let expr_id = ctx.heap[id].expression; + self.visit_expr(ctx, expr_id)?; + Ok(()) + } + + fn visit_indexing_expr(&mut self, ctx: &mut Ctx, id: IndexingExpressionId) -> VisitorResult { + debug_assert!(!self.performing_breadth_pass); + let indexing_expr = &ctx.heap[id]; + let subject_expr_id = indexing_expr.subject; + let index_expr_id = indexing_expr.index; + self.visit_expr(ctx, subject_expr_id)?; + self.visit_expr(ctx, index_expr_id)?; + Ok(()) + } + + fn visit_slicing_expr(&mut self, ctx: &mut Ctx, id: SlicingExpressionId) -> VisitorResult { + debug_assert!(!self.performing_breadth_pass); + // TODO: Same as the select expression: slicing depends on the type of + // the thing that is being sliced. + let slicing_expr = &ctx.heap[id]; + let subject_expr_id = slicing_expr.subject; + let from_expr_id = slicing_expr.from_index; + let to_expr_id = slicing_expr.to_index; + self.visit_expr(ctx, subject_expr_id)?; + self.visit_expr(ctx, from_expr_id)?; + self.visit_expr(ctx, to_expr_id)?; + Ok(()) + } + + fn visit_select_expr(&mut self, ctx: &mut Ctx, id: SelectExpressionId) -> VisitorResult { + debug_assert!(!self.performing_breadth_pass); + let expr_id = ctx.heap[id].subject; + self.visit_expr(ctx, expr_id)?; + + Ok(()) + } + + fn visit_array_expr(&mut self, ctx: &mut Ctx, id: ArrayExpressionId) -> VisitorResult { + debug_assert!(!self.performing_breadth_pass); + let array_expr = &ctx.heap[id]; + + let old_num_exprs = self.expression_buffer.len(); + self.expression_buffer.extend(&array_expr.elements); + let new_num_exprs = self.expression_buffer.len(); + + for field_expr_idx in old_num_exprs..new_num_exprs { + let field_expr_id = self.expression_buffer[field_expr_idx]; + self.visit_expr(ctx, field_expr_id)?; + } + + self.expression_buffer.truncate(old_num_exprs); + + Ok(()) + } + + fn visit_constant_expr(&mut self, _ctx: &mut Ctx, _id: ConstantExpressionId) -> VisitorResult { + debug_assert!(!self.performing_breadth_pass); + Ok(()) + } + + fn visit_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> VisitorResult { + debug_assert!(!self.performing_breadth_pass); + + let call_expr = &mut ctx.heap[id]; + + // Resolve the method to the appropriate definition and check the + // legality of the particular method call. + match &mut call_expr.method { + Method::Create => {}, + Method::Fires => { + if self.def_type != DefinitionType::Primitive { + return Err(ParseError2::new_error( + &ctx.module.source, call_expr.position, + "A call to 'fires' may only occur in primitive component definitions" + )); + } + }, + Method::Get => { + if self.def_type != DefinitionType::Primitive { + return Err(ParseError2::new_error( + &ctx.module.source, call_expr.position, + "A call to 'get' may only occur in primitive component definitions" + )); + } + }, + Method::Symbolic(symbolic) => { + // Find symbolic method + let found_symbol = self.find_symbol_of_type( + ctx.module.root_id, &ctx.symbols, &ctx.types, + &symbolic.identifier, TypeClass::Function + ); + let definition_id = match found_symbol { + FindOfTypeResult::Found(definition_id) => definition_id, + FindOfTypeResult::TypeMismatch(got_type_class) => { + return Err(ParseError2::new_error( + &ctx.module.source, symbolic.identifier.position, + &format!("Only functions can be called, this identifier points to a {}", got_type_class) + )) + }, + FindOfTypeResult::NotFound => { + return Err(ParseError2::new_error( + &ctx.module.source, symbolic.identifier.position, + &format!("Could not find a function with this name") + )) + } + }; + + symbolic.definition = Some(definition_id); + } + } + + // Parse all the arguments in the depth pass as well + let call_expr = &mut ctx.heap[id]; + let old_num_exprs = self.expression_buffer.len(); + self.expression_buffer.extend(&call_expr.arguments); + let new_num_exprs = self.expression_buffer.len(); + + for arg_expr_idx in old_num_exprs..new_num_exprs { + let arg_expr_id = self.expression_buffer[arg_expr_idx]; + self.visit_expr(ctx, arg_expr_id)?; + } + + self.expression_buffer.truncate(old_num_exprs); + + Ok(()) + } + + fn visit_variable_expr(&mut self, ctx: &mut Ctx, id: VariableExpressionId) -> VisitorResult { + debug_assert!(!self.performing_breadth_pass); + + let var_expr = &ctx.heap[id]; + let variable_id = self.find_variable(ctx, self.relative_pos_in_block, &var_expr.identifier)?; + let var_expr = &mut ctx.heap[id]; + var_expr.declaration = Some(variable_id); + + Ok(()) + } +} + +enum FindOfTypeResult { + // Identifier was exactly matched, type matched as well + Found(DefinitionId), + // Identifier was matched, but the type differs from the expected one + TypeMismatch(&'static str), + // Identifier could not be found + NotFound, +} + +impl ValidityAndLinkerVisitor { + //-------------------------------------------------------------------------- + // Special traversal + //-------------------------------------------------------------------------- + + /// Will visit a statement with a hint about its wrapping statement. This is + /// used to distinguish block statements with a wrapping synchronous + /// statement from normal block statements. + fn visit_stmt_with_hint(&mut self, ctx: &mut Ctx, id: StatementId, hint: Option) -> VisitorResult { + if let Statement::Block(block_stmt) = &ctx.heap[id] { + let block_id = block_stmt.this; + self.visit_block_stmt_with_hint(ctx, block_id, hint) + } else { + self.visit_stmt(ctx, id) + } + } + + fn visit_block_stmt_with_hint(&mut self, ctx: &mut Ctx, id: BlockStatementId, hint: Option) -> VisitorResult { + if self.performing_breadth_pass { + // Performing a breadth pass, so don't traverse into the statements + // of the block. + return Ok(()) + } + + // Set parent scope and relative position in the parent scope. Remember + // these values to set them back to the old values when we're done with + // the traversal of the block's statements. + let body = &mut ctx.heap[id]; + body.parent_scope = self.cur_scope.clone(); + println!("DEBUG: Assigning relative {} to block {}", self.relative_pos_in_block, id.0.0.index); + body.relative_pos_in_parent = self.relative_pos_in_block; + + let old_scope = self.cur_scope.replace(match hint { + Some(sync_id) => Scope::Synchronous((sync_id, id)), + None => Scope::Regular(id), + }); + let old_relative_pos = self.relative_pos_in_block; + + // Copy statement IDs into buffer + let old_num_stmts = self.statement_buffer.len(); + { + let body = &ctx.heap[id]; + self.statement_buffer.extend_from_slice(&body.statements); + } + let new_num_stmts = self.statement_buffer.len(); + + // Perform the breadth-first pass. Its main purpose is to find labeled + // statements such that we can find the `goto`-targets immediately when + // performing the depth pass + self.performing_breadth_pass = true; + for stmt_idx in old_num_stmts..new_num_stmts { + self.relative_pos_in_block = (stmt_idx - old_num_stmts) as u32; + self.visit_stmt(ctx, self.statement_buffer[stmt_idx])?; + } + + if !self.insert_buffer.is_empty() { + let body = &mut ctx.heap[id]; + for (insert_idx, (pos, stmt)) in self.insert_buffer.drain(..).enumerate() { + body.statements.insert(pos as usize + insert_idx, stmt); + } + } + + // And the depth pass. Because we're not actually visiting any inserted + // nodes because we're using the statement buffer, we may safely use the + // relative_pos_in_block counter. + self.performing_breadth_pass = false; + for stmt_idx in old_num_stmts..new_num_stmts { + self.relative_pos_in_block = (stmt_idx - old_num_stmts) as u32; + self.visit_stmt(ctx, self.statement_buffer[stmt_idx])?; + } + + self.cur_scope = old_scope; + self.relative_pos_in_block = old_relative_pos; + + // Pop statement buffer + debug_assert!(self.insert_buffer.is_empty(), "insert buffer not empty after depth pass"); + self.statement_buffer.truncate(old_num_stmts); + + Ok(()) + } + + //-------------------------------------------------------------------------- + // Utilities + //-------------------------------------------------------------------------- + + /// Adds a local variable to the current scope. It will also annotate the + /// `Local` in the AST with its relative position in the block. + fn checked_local_add(&mut self, ctx: &mut Ctx, relative_pos: u32, id: LocalId) -> Result<(), ParseError2> { + debug_assert!(self.cur_scope.is_some()); + + // Make sure we do not conflict with any global symbols + { + let ident = &ctx.heap[id].identifier; + if let Some(symbol) = ctx.symbols.resolve_symbol(ctx.module.root_id, &ident.value) { + return Err( + ParseError2::new_error(&ctx.module.source, ident.position, "Local variable declaration conflicts with symbol") + .with_postfixed_info(&ctx.module.source, symbol.position, "Conflicting symbol is found here") + ); + } + } + + let local = &mut ctx.heap[id]; + local.relative_pos_in_block = relative_pos; + + // Make sure we do not shadow any variables in any of the scopes. Note + // that variables in parent scopes may be declared later + let local = &ctx.heap[id]; + let mut scope = self.cur_scope.as_ref().unwrap(); + let mut local_relative_pos = self.relative_pos_in_block; + + loop { + debug_assert!(scope.is_block(), "scope is not a block"); + let block = &ctx.heap[scope.to_block()]; + for other_local_id in &block.locals { + let other_local = &ctx.heap[*other_local_id]; + // Position check in case another variable with the same name + // is defined in a higher-level scope, but later than the scope + // in which the current variable resides. + if local.this != *other_local_id && + local_relative_pos >= other_local.relative_pos_in_block && + local.identifier.value == other_local.identifier.value { + // Collision within this scope + return Err( + ParseError2::new_error(&ctx.module.source, local.position, "Local variable name conflicts with another variable") + .with_postfixed_info(&ctx.module.source, other_local.position, "Previous variable is found here") + ); + } + } + + // Current scope is fine, move to parent scope if any + debug_assert!(block.parent_scope.is_some(), "block scope does not have a parent"); + scope = block.parent_scope.as_ref().unwrap(); + if let Scope::Definition(definition_id) = scope { + // At outer scope, check parameters of function/component + for parameter_id in ctx.heap[*definition_id].parameters() { + let parameter = &ctx.heap[*parameter_id]; + if local.identifier.value == parameter.identifier.value { + return Err( + ParseError2::new_error(&ctx.module.source, local.position, "Local variable name conflicts with parameter") + .with_postfixed_info(&ctx.module.source, parameter.position, "Parameter definition is found here") + ); + } + } + + break; + } + + // If here, then we are dealing with a block-like parent block + local_relative_pos = ctx.heap[scope.to_block()].relative_pos_in_parent; + } + + // No collisions at all + let block = &mut ctx.heap[self.cur_scope.as_ref().unwrap().to_block()]; + block.locals.push(id); + + Ok(()) + } + + /// Finds a variable in the visitor's scope that must appear before the + /// specified relative position within that block. + fn find_variable(&self, ctx: &Ctx, mut relative_pos: u32, identifier: &NamespacedIdentifier) -> Result { + debug_assert!(self.cur_scope.is_some()); + debug_assert!(identifier.num_namespaces > 0); + + // TODO: Update once globals are possible as well + if identifier.num_namespaces > 1 { + todo!("Implement namespaced constant seeking") + } + + // TODO: May still refer to an alias of a global symbol using a single + // identifier in the namespace. + // No need to use iterator over namespaces if here + let mut scope = self.cur_scope.as_ref().unwrap(); + + loop { + debug_assert!(scope.is_block()); + let block = &ctx.heap[scope.to_block()]; + + for local_id in &block.locals { + let local = &ctx.heap[*local_id]; + + if local.relative_pos_in_block < relative_pos && local.identifier.value == identifier.value { + return Ok(local_id.upcast()); + } + } + + debug_assert!(block.parent_scope.is_some()); + scope = block.parent_scope.as_ref().unwrap(); + if !scope.is_block() { + // Definition scope, need to check arguments to definition + match scope { + Scope::Definition(definition_id) => { + let definition = &ctx.heap[*definition_id]; + for parameter_id in definition.parameters() { + let parameter = &ctx.heap[*parameter_id]; + if parameter.identifier.value == identifier.value { + return Ok(parameter_id.upcast()); + } + } + }, + _ => unreachable!(), + } + + // Variable could not be found + return Err(ParseError2::new_error( + &ctx.module.source, identifier.position, "This variable is not declared" + )); + } else { + relative_pos = block.relative_pos_in_parent; + } + } + } + + /// Adds a particular label to the current scope. Will return an error if + /// there is another label with the same name visible in the current scope. + fn checked_label_add(&mut self, ctx: &mut Ctx, id: LabeledStatementId) -> Result<(), ParseError2> { + debug_assert!(self.cur_scope.is_some()); + + // Make sure label is not defined within the current scope or any of the + // parent scope. + let label = &ctx.heap[id]; + let mut scope = self.cur_scope.as_ref().unwrap(); + + loop { + debug_assert!(scope.is_block(), "scope is not a block"); + let block = &ctx.heap[scope.to_block()]; + for other_label_id in &block.labels { + let other_label = &ctx.heap[*other_label_id]; + if other_label.label.value == label.label.value { + // Collision + return Err( + ParseError2::new_error(&ctx.module.source, label.position, "Label name conflicts with another label") + .with_postfixed_info(&ctx.module.source, other_label.position, "Other label is found here") + ); + } + } + + debug_assert!(block.parent_scope.is_some(), "block scope does not have a parent"); + scope = block.parent_scope.as_ref().unwrap(); + if !scope.is_block() { + break; + } + } + + // No collisions + let block = &mut ctx.heap[self.cur_scope.as_ref().unwrap().to_block()]; + block.labels.push(id); + + Ok(()) + } + + /// Finds a particular labeled statement by its identifier. Once found it + /// will make sure that the target label does not skip over any variable + /// declarations within the scope in which the label was found. + fn find_label(&self, ctx: &Ctx, identifier: &Identifier) -> Result { + debug_assert!(self.cur_scope.is_some()); + + let mut scope = self.cur_scope.as_ref().unwrap(); + loop { + debug_assert!(scope.is_block(), "scope is not a block"); + let relative_scope_pos = ctx.heap[scope.to_block()].relative_pos_in_parent; + + let block = &ctx.heap[scope.to_block()]; + for label_id in &block.labels { + let label = &ctx.heap[*label_id]; + if label.label.value == identifier.value { + for local_id in &block.locals { + // TODO: Better to do this in control flow analysis, it + // is legal to skip over a variable declaration if it + // is not actually being used. I might be missing + // something here when laying out the bytecode... + let local = &ctx.heap[*local_id]; + if local.relative_pos_in_block > relative_scope_pos && local.relative_pos_in_block < label.relative_pos_in_block { + return Err( + ParseError2::new_error(&ctx.module.source, identifier.position, "This target label skips over a variable declaration") + .with_postfixed_info(&ctx.module.source, label.position, "Because it jumps to this label") + .with_postfixed_info(&ctx.module.source, local.position, "Which skips over this variable") + ); + } + } + return Ok(*label_id); + } + } + + debug_assert!(block.parent_scope.is_some(), "block scope does not have a parent"); + scope = block.parent_scope.as_ref().unwrap(); + if !scope.is_block() { + return Err(ParseError2::new_error(&ctx.module.source, identifier.position, "Could not find this label")); + } + + } + } + + /// Finds a particular symbol in the symbol table which must correspond to + /// a definition of a particular type. + // Note: root_id, symbols and types passed in explicitly to prevent + // borrowing errors + fn find_symbol_of_type( + &self, root_id: RootId, symbols: &SymbolTable, types: &TypeTable, + identifier: &NamespacedIdentifier, expected_type_class: TypeClass + ) -> FindOfTypeResult { + // Find symbol associated with identifier + let symbol = symbols.resolve_namespaced_symbol(root_id, &identifier); + if symbol.is_none() { return FindOfTypeResult::NotFound; } + + let (symbol, iter) = symbol.unwrap(); + if iter.num_remaining() != 0 { return FindOfTypeResult::NotFound; } + + match &symbol.symbol { + Symbol::Definition((_, definition_id)) => { + // Make sure definition is of the expected type + let definition_type = types.get_definition(definition_id); + debug_assert!(definition_type.is_some(), "Found symbol '{}' in symbol table, but not in type table", String::from_utf8_lossy(&identifier.value)); + let definition_type_class = definition_type.unwrap().type_class(); + + if definition_type_class != expected_type_class { + FindOfTypeResult::TypeMismatch(definition_type_class.display_name()) + } else { + FindOfTypeResult::Found(*definition_id) + } + }, + Symbol::Namespace(_) => FindOfTypeResult::TypeMismatch("namespace"), + } + } + + /// This function will check if the provided while statement ID has a block + /// statement that is one of our current parents. + fn has_parent_while_scope(&self, ctx: &Ctx, id: WhileStatementId) -> bool { + debug_assert!(self.cur_scope.is_some()); + let mut scope = self.cur_scope.as_ref().unwrap(); + let while_stmt = &ctx.heap[id]; + loop { + debug_assert!(scope.is_block()); + let block = scope.to_block(); + if while_stmt.body == block.upcast() { + return true; + } + + let block = &ctx.heap[block]; + debug_assert!(block.parent_scope.is_some(), "block scope does not have a parent"); + scope = block.parent_scope.as_ref().unwrap(); + if !scope.is_block() { + return false; + } + } + } + + /// This function should be called while dealing with break/continue + /// statements. It will try to find the targeted while statement, using the + /// target label if provided. If a valid target is found then the loop's + /// ID will be returned, otherwise a parsing error is constructed. + /// The provided input position should be the position of the break/continue + /// statement. + fn resolve_break_or_continue_target(&self, ctx: &Ctx, position: InputPosition, label: &Option) -> Result { + let target = match label { + Some(label) => { + let target_id = self.find_label(ctx, label)?; + + // Make sure break target is a while statement + let target = &ctx.heap[target_id]; + if let Statement::While(target_stmt) = &ctx.heap[target.body] { + // Even though we have a target while statement, the break might not be + // present underneath this particular labeled while statement + if !self.has_parent_while_scope(ctx, target_stmt.this) { + ParseError2::new_error(&ctx.module.source, label.position, "Break statement is not nested under the target label's while statement") + .with_postfixed_info(&ctx.module.source, target.position, "The targeted label is found here"); + } + + target_stmt.this + } else { + return Err( + ParseError2::new_error(&ctx.module.source, label.position, "Incorrect break target label, it must target a while loop") + .with_postfixed_info(&ctx.module.source, target.position, "The targeted label is found here") + ); + } + }, + None => { + // Use the enclosing while statement, the break must be + // nested within that while statement + if self.in_while.is_none() { + return Err( + ParseError2::new_error(&ctx.module.source, position, "Break statement is not nested under a while loop") + ); + } + + self.in_while.unwrap() + } + }; + + // We have a valid target for the break statement. But we need to + // make sure we will not break out of a synchronous block + { + let target_while = &ctx.heap[target]; + if target_while.in_sync != self.in_sync { + // Break is nested under while statement, so can only escape a + // sync block if the sync is nested inside the while statement. + debug_assert!(self.in_sync.is_some()); + let sync_stmt = &ctx.heap[self.in_sync.unwrap()]; + return Err( + ParseError2::new_error(&ctx.module.source, position, "Break may not escape the surrounding synchronous block") + .with_postfixed_info(&ctx.module.source, target_while.position, "The break escapes out of this loop") + .with_postfixed_info(&ctx.module.source, sync_stmt.position, "And would therefore escape this synchronous block") + ); + } + } + + Ok(target) + } +} \ No newline at end of file