diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index 808bf838ab760e31342e87290ec309670474fc37..4f1d68b6934fab0b8de0c1d17c3b44df8a20b741 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -1,3 +1,6 @@ +// TODO: @cleanup, rigorous cleanup of dead code and silly object-oriented +// trait impls where I deem them unfit. + use std::fmt; use std::fmt::{Debug, Display, Formatter}; use std::ops::{Index, IndexMut}; @@ -11,14 +14,32 @@ use crate::protocol::inputsource::*; /// Helper macro that defines a type alias for a AST element ID. In this case /// only used to alias the `Id` types. macro_rules! define_aliased_ast_id { + // Variant where we just defined the alias, without any indexing ($name:ident, $parent:ty) => { pub type $name = $parent; + }; + // Variant where we define the type, and the Index and IndexMut traits + ($name:ident, $parent:ty, $indexed_type:ty, $indexed_arena:ident) => { + define_aliased_ast_id!($name, $parent); + impl Index<$name> for Heap { + type Output = $indexed_type; + fn index(&self, index: $name) -> &Self::Output { + &self.$indexed_arena[index] + } + } + + impl IndexMut<$name> for Heap { + fn index_mut(&mut self, index: $name) -> &mut Self::Output { + &mut self.$indexed_arena[index] + } + } } } -/// Helper macro that defines a subtype for a particular variant of an AST -/// element ID. +/// Helper macro that defines a wrapper type for a particular variant of an AST +/// element ID. Only used to define single-wrapping IDs. macro_rules! define_new_ast_id { + // Variant where we just defined the new type, without any indexing ($name:ident, $parent:ty) => { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)] pub struct $name (pub(crate) $parent); @@ -29,57 +50,81 @@ macro_rules! 
define_new_ast_id { } } }; + // Variant where we define the type, and the Index and IndexMut traits + ($name:ident, $parent:ty, $indexed_type:ty, $wrapper_type:path, $indexed_arena:ident) => { + define_new_ast_id!($name, $parent); + impl Index<$name> for Heap { + type Output = $indexed_type; + fn index(&self, index: $name) -> &Self::Output { + if let $wrapper_type(v) = &self.$indexed_arena[index.0] { + v + } else { + unreachable!() + } + } + } + + impl IndexMut<$name> for Heap { + fn index_mut(&mut self, index: $name) -> &mut Self::Output { + if let $wrapper_type(v) = &mut self.$indexed_arena[index.0] { + v + } else { + unreachable!() + } + } + } + } } -define_aliased_ast_id!(RootId, Id); -define_aliased_ast_id!(PragmaId, Id); -define_aliased_ast_id!(ImportId, Id); -define_aliased_ast_id!(ParserTypeId, Id); +define_aliased_ast_id!(RootId, Id, Root, protocol_descriptions); +define_aliased_ast_id!(PragmaId, Id, Pragma, pragmas); +define_aliased_ast_id!(ImportId, Id, Import, imports); +define_aliased_ast_id!(ParserTypeId, Id, ParserType, parser_types); -define_aliased_ast_id!(VariableId, Id); -define_new_ast_id!(ParameterId, VariableId); -define_new_ast_id!(LocalId, VariableId); +define_aliased_ast_id!(VariableId, Id, Variable, variables); +define_new_ast_id!(ParameterId, VariableId, Parameter, Variable::Parameter, variables); +define_new_ast_id!(LocalId, VariableId, Local, Variable::Local, variables); -define_aliased_ast_id!(DefinitionId, Id); -define_new_ast_id!(StructId, DefinitionId); -define_new_ast_id!(EnumId, DefinitionId); -define_new_ast_id!(ComponentId, DefinitionId); -define_new_ast_id!(FunctionId, DefinitionId); +define_aliased_ast_id!(DefinitionId, Id, Definition, definitions); +define_new_ast_id!(StructId, DefinitionId, StructDefinition, Definition::Struct, definitions); +define_new_ast_id!(EnumId, DefinitionId, EnumDefinition, Definition::Enum, definitions); +define_new_ast_id!(ComponentId, DefinitionId, Component, Definition::Component, definitions); 
+define_new_ast_id!(FunctionId, DefinitionId, Function, Definition::Function, definitions); -define_aliased_ast_id!(StatementId, Id); -define_new_ast_id!(BlockStatementId, StatementId); -define_new_ast_id!(LocalStatementId, StatementId); +define_aliased_ast_id!(StatementId, Id, Statement, statements); +define_new_ast_id!(BlockStatementId, StatementId, BlockStatement, Statement::Block, statements); +define_new_ast_id!(LocalStatementId, StatementId, LocalStatement, Statement::Local, statements); define_new_ast_id!(MemoryStatementId, LocalStatementId); define_new_ast_id!(ChannelStatementId, LocalStatementId); -define_new_ast_id!(SkipStatementId, StatementId); -define_new_ast_id!(LabeledStatementId, StatementId); -define_new_ast_id!(IfStatementId, StatementId); -define_new_ast_id!(EndIfStatementId, StatementId); -define_new_ast_id!(WhileStatementId, StatementId); -define_new_ast_id!(EndWhileStatementId, StatementId); -define_new_ast_id!(BreakStatementId, StatementId); -define_new_ast_id!(ContinueStatementId, StatementId); -define_new_ast_id!(SynchronousStatementId, StatementId); -define_new_ast_id!(EndSynchronousStatementId, StatementId); -define_new_ast_id!(ReturnStatementId, StatementId); -define_new_ast_id!(AssertStatementId, StatementId); -define_new_ast_id!(GotoStatementId, StatementId); -define_new_ast_id!(NewStatementId, StatementId); -define_new_ast_id!(PutStatementId, StatementId); -define_new_ast_id!(ExpressionStatementId, StatementId); - -define_aliased_ast_id!(ExpressionId, Id); -define_new_ast_id!(AssignmentExpressionId, ExpressionId); -define_new_ast_id!(ConditionalExpressionId, ExpressionId); -define_new_ast_id!(BinaryExpressionId, ExpressionId); -define_new_ast_id!(UnaryExpressionId, ExpressionId); -define_new_ast_id!(IndexingExpressionId, ExpressionId); -define_new_ast_id!(SlicingExpressionId, ExpressionId); -define_new_ast_id!(SelectExpressionId, ExpressionId); -define_new_ast_id!(ArrayExpressionId, ExpressionId); 
-define_new_ast_id!(ConstantExpressionId, ExpressionId); -define_new_ast_id!(CallExpressionId, ExpressionId); -define_new_ast_id!(VariableExpressionId, ExpressionId); +define_new_ast_id!(SkipStatementId, StatementId, SkipStatement, Statement::Skip, statements); +define_new_ast_id!(LabeledStatementId, StatementId, LabeledStatement, Statement::Labeled, statements); +define_new_ast_id!(IfStatementId, StatementId, IfStatement, Statement::If, statements); +define_new_ast_id!(EndIfStatementId, StatementId, EndIfStatement, Statement::EndIf, statements); +define_new_ast_id!(WhileStatementId, StatementId, WhileStatement, Statement::While, statements); +define_new_ast_id!(EndWhileStatementId, StatementId, EndWhileStatement, Statement::EndWhile, statements); +define_new_ast_id!(BreakStatementId, StatementId, BreakStatement, Statement::Break, statements); +define_new_ast_id!(ContinueStatementId, StatementId, ContinueStatement, Statement::Continue, statements); +define_new_ast_id!(SynchronousStatementId, StatementId, SynchronousStatement, Statement::Synchronous, statements); +define_new_ast_id!(EndSynchronousStatementId, StatementId, EndSynchronousStatement, Statement::EndSynchronous, statements); +define_new_ast_id!(ReturnStatementId, StatementId, ReturnStatement, Statement::Return, statements); +define_new_ast_id!(AssertStatementId, StatementId, AssertStatement, Statement::Assert, statements); +define_new_ast_id!(GotoStatementId, StatementId, GotoStatement, Statement::Goto, statements); +define_new_ast_id!(NewStatementId, StatementId, NewStatement, Statement::New, statements); +define_new_ast_id!(PutStatementId, StatementId, PutStatement, Statement::Put, statements); +define_new_ast_id!(ExpressionStatementId, StatementId, ExpressionStatement, Statement::Expression, statements); + +define_aliased_ast_id!(ExpressionId, Id, Expression, expressions); +define_new_ast_id!(AssignmentExpressionId, ExpressionId, AssignmentExpression, Expression::Assignment, expressions); 
+define_new_ast_id!(ConditionalExpressionId, ExpressionId, ConditionalExpression, Expression::Conditional, expressions); +define_new_ast_id!(BinaryExpressionId, ExpressionId, BinaryExpression, Expression::Binary, expressions); +define_new_ast_id!(UnaryExpressionId, ExpressionId, UnaryExpression, Expression::Unary, expressions); +define_new_ast_id!(IndexingExpressionId, ExpressionId, IndexingExpression, Expression::Indexing, expressions); +define_new_ast_id!(SlicingExpressionId, ExpressionId, SlicingExpression, Expression::Slicing, expressions); +define_new_ast_id!(SelectExpressionId, ExpressionId, SelectExpression, Expression::Select, expressions); +define_new_ast_id!(ArrayExpressionId, ExpressionId, ArrayExpression, Expression::Array, expressions); +define_new_ast_id!(ConstantExpressionId, ExpressionId, ConstantExpression, Expression::Constant, expressions); +define_new_ast_id!(CallExpressionId, ExpressionId, CallExpression, Expression::Call, expressions); +define_new_ast_id!(VariableExpressionId, ExpressionId, VariableExpression, Expression::Variable, expressions); // TODO: @cleanup - pub qualifiers can be removed once done #[derive(Debug, serde::Serialize, serde::Deserialize)] @@ -435,133 +480,6 @@ impl Heap { } } -impl Index for Heap { - type Output = Root; - fn index(&self, index: RootId) -> &Self::Output { - &self.protocol_descriptions[index] - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: RootId) -> &mut Self::Output { - &mut self.protocol_descriptions[index] - } -} - -impl Index for Heap { - type Output = Pragma; - fn index(&self, index: PragmaId) -> &Self::Output { - &self.pragmas[index] - } -} - -impl Index for Heap { - type Output = Import; - fn index(&self, index: ImportId) -> &Self::Output { - &self.imports[index] - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: ImportId) -> &mut Self::Output { - &mut self.imports[index] - } -} - -impl Index for Heap { - type Output = ParserType; - fn index(&self, index: 
ParserTypeId) -> &Self::Output { - &self.parser_types[index] - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: ParserTypeId) -> &mut Self::Output { - &mut self.parser_types[index] - } -} - -impl Index for Heap { - type Output = Variable; - fn index(&self, index: VariableId) -> &Self::Output { - &self.variables[index] - } -} - -impl Index for Heap { - type Output = Parameter; - fn index(&self, index: ParameterId) -> &Self::Output { - &self.variables[index.0].as_parameter() - } -} - -impl Index for Heap { - type Output = Local; - fn index(&self, index: LocalId) -> &Self::Output { - &self.variables[index.0].as_local() - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: LocalId) -> &mut Self::Output { - self.variables[index.0].as_local_mut() - } -} - -impl Index for Heap { - type Output = Definition; - fn index(&self, index: DefinitionId) -> &Self::Output { - &self.definitions[index] - } -} - -impl Index for Heap { - type Output = Component; - fn index(&self, index: ComponentId) -> &Self::Output { - &self.definitions[index.0].as_component() - } -} - -impl Index for Heap { - type Output = Function; - fn index(&self, index: FunctionId) -> &Self::Output { - &self.definitions[index.0].as_function() - } -} - -impl Index for Heap { - type Output = Statement; - fn index(&self, index: StatementId) -> &Self::Output { - &self.statements[index] - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: StatementId) -> &mut Self::Output { - &mut self.statements[index] - } -} - -impl Index for Heap { - type Output = BlockStatement; - fn index(&self, index: BlockStatementId) -> &Self::Output { - &self.statements[index.0].as_block() - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: BlockStatementId) -> &mut Self::Output { - (&mut self.statements[index.0]).as_block_mut() - } -} - -impl Index for Heap { - type Output = LocalStatement; - fn index(&self, index: LocalStatementId) -> &Self::Output { - 
&self.statements[index.0].as_local() - } -} - impl Index for Heap { type Output = MemoryStatement; fn index(&self, index: MemoryStatementId) -> &Self::Output { @@ -576,249 +494,6 @@ impl Index for Heap { } } -impl Index for Heap { - type Output = SkipStatement; - fn index(&self, index: SkipStatementId) -> &Self::Output { - &self.statements[index.0].as_skip() - } -} - -impl Index for Heap { - type Output = LabeledStatement; - fn index(&self, index: LabeledStatementId) -> &Self::Output { - &self.statements[index.0].as_labeled() - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: LabeledStatementId) -> &mut Self::Output { - (&mut self.statements[index.0]).as_labeled_mut() - } -} - -impl Index for Heap { - type Output = IfStatement; - fn index(&self, index: IfStatementId) -> &Self::Output { - &self.statements[index.0].as_if() - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: IfStatementId) -> &mut Self::Output { - self.statements[index.0].as_if_mut() - } -} - -impl Index for Heap { - type Output = EndIfStatement; - fn index(&self, index: EndIfStatementId) -> &Self::Output { - &self.statements[index.0].as_end_if() - } -} - -impl Index for Heap { - type Output = WhileStatement; - fn index(&self, index: WhileStatementId) -> &Self::Output { - &self.statements[index.0].as_while() - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: WhileStatementId) -> &mut Self::Output { - (&mut self.statements[index.0]).as_while_mut() - } -} - -impl Index for Heap { - type Output = BreakStatement; - fn index(&self, index: BreakStatementId) -> &Self::Output { - &self.statements[index.0].as_break() - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: BreakStatementId) -> &mut Self::Output { - (&mut self.statements[index.0]).as_break_mut() - } -} - -impl Index for Heap { - type Output = ContinueStatement; - fn index(&self, index: ContinueStatementId) -> &Self::Output { - &self.statements[index.0].as_continue() - } -} - -impl 
IndexMut for Heap { - fn index_mut(&mut self, index: ContinueStatementId) -> &mut Self::Output { - (&mut self.statements[index.0]).as_continue_mut() - } -} - -impl Index for Heap { - type Output = SynchronousStatement; - fn index(&self, index: SynchronousStatementId) -> &Self::Output { - &self.statements[index.0].as_synchronous() - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: SynchronousStatementId) -> &mut Self::Output { - (&mut self.statements[index.0]).as_synchronous_mut() - } -} - -impl Index for Heap { - type Output = EndSynchronousStatement; - fn index(&self, index: EndSynchronousStatementId) -> &Self::Output { - &self.statements[index.0].as_end_synchronous() - } -} - -impl Index for Heap { - type Output = ReturnStatement; - fn index(&self, index: ReturnStatementId) -> &Self::Output { - &self.statements[index.0].as_return() - } -} - -impl Index for Heap { - type Output = AssertStatement; - fn index(&self, index: AssertStatementId) -> &Self::Output { - &self.statements[index.0].as_assert() - } -} - -impl Index for Heap { - type Output = GotoStatement; - fn index(&self, index: GotoStatementId) -> &Self::Output { - &self.statements[index.0].as_goto() - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: GotoStatementId) -> &mut Self::Output { - (&mut self.statements[index.0]).as_goto_mut() - } -} - -impl Index for Heap { - type Output = NewStatement; - fn index(&self, index: NewStatementId) -> &Self::Output { - &self.statements[index.0].as_new() - } -} - -impl Index for Heap { - type Output = PutStatement; - fn index(&self, index: PutStatementId) -> &Self::Output { - &self.statements[index.0].as_put() - } -} - -impl Index for Heap { - type Output = ExpressionStatement; - fn index(&self, index: ExpressionStatementId) -> &Self::Output { - &self.statements[index.0].as_expression() - } -} - -impl Index for Heap { - type Output = Expression; - fn index(&self, index: ExpressionId) -> &Self::Output { - &self.expressions[index] - } -} 
- -impl Index for Heap { - type Output = AssignmentExpression; - fn index(&self, index: AssignmentExpressionId) -> &Self::Output { - &self.expressions[index.0].as_assignment() - } -} - -impl Index for Heap { - type Output = ConditionalExpression; - fn index(&self, index: ConditionalExpressionId) -> &Self::Output { - &self.expressions[index.0].as_conditional() - } -} - -impl Index for Heap { - type Output = BinaryExpression; - fn index(&self, index: BinaryExpressionId) -> &Self::Output { - &self.expressions[index.0].as_binary() - } -} - -impl Index for Heap { - type Output = UnaryExpression; - fn index(&self, index: UnaryExpressionId) -> &Self::Output { - &self.expressions[index.0].as_unary() - } -} - -impl Index for Heap { - type Output = IndexingExpression; - fn index(&self, index: IndexingExpressionId) -> &Self::Output { - &self.expressions[index.0].as_indexing() - } -} - -impl Index for Heap { - type Output = SlicingExpression; - fn index(&self, index: SlicingExpressionId) -> &Self::Output { - &self.expressions[index.0].as_slicing() - } -} - -impl Index for Heap { - type Output = SelectExpression; - fn index(&self, index: SelectExpressionId) -> &Self::Output { - &self.expressions[index.0].as_select() - } -} - -impl Index for Heap { - type Output = ArrayExpression; - fn index(&self, index: ArrayExpressionId) -> &Self::Output { - &self.expressions[index.0].as_array() - } -} - -impl Index for Heap { - type Output = ConstantExpression; - fn index(&self, index: ConstantExpressionId) -> &Self::Output { - &self.expressions[index.0].as_constant() - } -} - -impl Index for Heap { - type Output = CallExpression; - fn index(&self, index: CallExpressionId) -> &Self::Output { - &self.expressions[index.0].as_call() - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: CallExpressionId) -> &mut Self::Output { - (&mut self.expressions[index.0]).as_call_mut() - } -} - -impl Index for Heap { - type Output = VariableExpression; - fn index(&self, index: 
VariableExpressionId) -> &Self::Output { - &self.expressions[index.0].as_variable() - } -} - -impl IndexMut for Heap { - fn index_mut(&mut self, index: VariableExpressionId) -> &mut Self::Output { - (&mut self.expressions[index.0]).as_variable_mut() - } -} - #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Root { pub this: RootId, @@ -2233,6 +1908,19 @@ impl SyntaxElement for ExpressionStatement { } } +#[derive(Debug, PartialEq, Eq, Clone, Copy, serde::Serialize, serde::Deserialize)] +pub enum ExpressionParent { + None, // only set during initial parsing + Memory(MemoryStatementId), + If(IfStatementId), + While(WhileStatementId), + Return(ReturnStatementId), + Assert(AssertStatementId), + Put(PutStatementId, u32), // index of arg + ExpressionStmt(ExpressionStatementId), + Expression(ExpressionId, u32) // index within expression (e.g LHS or RHS of expression) +} + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub enum Expression { Assignment(AssignmentExpression), @@ -2327,6 +2015,37 @@ impl Expression { _ => panic!("Unable to cast `Expression` to `VariableExpression`"), } } + // TODO: @cleanup + pub fn parent(&self) -> &ExpressionParent { + match self { + Expression::Assignment(expr) => &expr.parent, + Expression::Conditional(expr) => &expr.parent, + Expression::Binary(expr) => &expr.parent, + Expression::Unary(expr) => &expr.parent, + Expression::Indexing(expr) => &expr.parent, + Expression::Slicing(expr) => &expr.parent, + Expression::Select(expr) => &expr.parent, + Expression::Array(expr) => &expr.parent, + Expression::Constant(expr) => &expr.parent, + Expression::Call(expr) => &expr.parent, + Expression::Variable(expr) => &expr.parent, + } + } + pub fn set_parent(&mut self, parent: ExpressionParent) { + match self { + Expression::Assignment(expr) => expr.parent = parent, + Expression::Conditional(expr) => expr.parent = parent, + Expression::Binary(expr) => expr.parent = parent, + Expression::Unary(expr) => expr.parent = 
parent, + Expression::Indexing(expr) => expr.parent = parent, + Expression::Slicing(expr) => expr.parent = parent, + Expression::Select(expr) => expr.parent = parent, + Expression::Array(expr) => expr.parent = parent, + Expression::Constant(expr) => expr.parent = parent, + Expression::Call(expr) => expr.parent = parent, + Expression::Variable(expr) => expr.parent = parent, + } + } } impl SyntaxElement for Expression { @@ -2370,6 +2089,8 @@ pub struct AssignmentExpression { pub left: ExpressionId, pub operation: AssignmentOperator, pub right: ExpressionId, + // Phase 2: linker + pub parent: ExpressionParent, } impl SyntaxElement for AssignmentExpression { @@ -2386,6 +2107,8 @@ pub struct ConditionalExpression { pub test: ExpressionId, pub true_expression: ExpressionId, pub false_expression: ExpressionId, + // Phase 2: linker + pub parent: ExpressionParent, } impl SyntaxElement for ConditionalExpression { @@ -2425,6 +2148,8 @@ pub struct BinaryExpression { pub left: ExpressionId, pub operation: BinaryOperator, pub right: ExpressionId, + // Phase 2: linker + pub parent: ExpressionParent, } impl SyntaxElement for BinaryExpression { @@ -2452,6 +2177,8 @@ pub struct UnaryExpression { pub position: InputPosition, pub operation: UnaryOperation, pub expression: ExpressionId, + // Phase 2: linker + pub parent: ExpressionParent, } impl SyntaxElement for UnaryExpression { @@ -2467,6 +2194,8 @@ pub struct IndexingExpression { pub position: InputPosition, pub subject: ExpressionId, pub index: ExpressionId, + // Phase 2: linker + pub parent: ExpressionParent, } impl SyntaxElement for IndexingExpression { @@ -2483,6 +2212,8 @@ pub struct SlicingExpression { pub subject: ExpressionId, pub from_index: ExpressionId, pub to_index: ExpressionId, + // Phase 2: linker + pub parent: ExpressionParent, } impl SyntaxElement for SlicingExpression { @@ -2498,6 +2229,8 @@ pub struct SelectExpression { pub position: InputPosition, pub subject: ExpressionId, pub field: Field, + // Phase 2: linker 
+ pub parent: ExpressionParent, } impl SyntaxElement for SelectExpression { @@ -2512,6 +2245,8 @@ pub struct ArrayExpression { // Phase 1: parser pub position: InputPosition, pub elements: Vec, + // Phase 2: linker + pub parent: ExpressionParent, } impl SyntaxElement for ArrayExpression { @@ -2527,6 +2262,8 @@ pub struct CallExpression { pub position: InputPosition, pub method: Method, pub arguments: Vec, + // Phase 2: linker + pub parent: ExpressionParent, } impl SyntaxElement for CallExpression { @@ -2541,6 +2278,8 @@ pub struct ConstantExpression { // Phase 1: parser pub position: InputPosition, pub value: Constant, + // Phase 2: linker + pub parent: ExpressionParent, } impl SyntaxElement for ConstantExpression { @@ -2557,6 +2296,7 @@ pub struct VariableExpression { pub identifier: NamespacedIdentifier, // Phase 2: linker pub declaration: Option, + pub parent: ExpressionParent, } impl SyntaxElement for VariableExpression { diff --git a/src/protocol/ast_printer.rs b/src/protocol/ast_printer.rs index 74e40d627796dbb44bf0397b49547a060bc1564c..917bc6318c13ae40e61c36683305cc246e4bc85a 100644 --- a/src/protocol/ast_printer.rs +++ b/src/protocol/ast_printer.rs @@ -489,6 +489,8 @@ impl ASTWriter { self.write_expr(heap, expr.left, indent3); self.kv(indent2).with_s_key("Right"); self.write_expr(heap, expr.right, indent3); + self.kv(indent2).with_s_key("Parent") + .with_custom_val(|v| write_expression_parent(v, &expr.parent)); }, Expression::Conditional(expr) => { self.kv(indent).with_id(PREFIX_CONDITIONAL_EXPR_ID, expr.this.0.index) @@ -499,6 +501,8 @@ impl ASTWriter { self.write_expr(heap, expr.true_expression, indent3); self.kv(indent2).with_s_key("FalseExpression"); self.write_expr(heap, expr.false_expression, indent3); + self.kv(indent2).with_s_key("Parent") + .with_custom_val(|v| write_expression_parent(v, &expr.parent)); }, Expression::Binary(expr) => { self.kv(indent).with_id(PREFIX_BINARY_EXPR_ID, expr.this.0.index) @@ -508,6 +512,8 @@ impl ASTWriter { 
self.write_expr(heap, expr.left, indent3); self.kv(indent2).with_s_key("Right"); self.write_expr(heap, expr.right, indent3); + self.kv(indent2).with_s_key("Parent") + .with_custom_val(|v| write_expression_parent(v, &expr.parent)); }, Expression::Unary(expr) => { self.kv(indent).with_id(PREFIX_UNARY_EXPR_ID, expr.this.0.index) @@ -515,6 +521,8 @@ impl ASTWriter { self.kv(indent2).with_s_key("Operation").with_debug_val(&expr.operation); self.kv(indent2).with_s_key("Argument"); self.write_expr(heap, expr.expression, indent3); + self.kv(indent2).with_s_key("Parent") + .with_custom_val(|v| write_expression_parent(v, &expr.parent)); }, Expression::Indexing(expr) => { self.kv(indent).with_id(PREFIX_INDEXING_EXPR_ID, expr.this.0.index) @@ -523,6 +531,8 @@ impl ASTWriter { self.write_expr(heap, expr.subject, indent3); self.kv(indent2).with_s_key("Index"); self.write_expr(heap, expr.index, indent3); + self.kv(indent2).with_s_key("Parent") + .with_custom_val(|v| write_expression_parent(v, &expr.parent)); }, Expression::Slicing(expr) => { self.kv(indent).with_id(PREFIX_SLICING_EXPR_ID, expr.this.0.index) @@ -533,6 +543,8 @@ impl ASTWriter { self.write_expr(heap, expr.from_index, indent3); self.kv(indent2).with_s_key("ToIndex"); self.write_expr(heap, expr.to_index, indent3); + self.kv(indent2).with_s_key("Parent") + .with_custom_val(|v| write_expression_parent(v, &expr.parent)); }, Expression::Select(expr) => { self.kv(indent).with_id(PREFIX_SELECT_EXPR_ID, expr.this.0.index) @@ -548,6 +560,8 @@ impl ASTWriter { self.kv(indent2).with_s_key("Field").with_ascii_val(&field.value); } } + self.kv(indent2).with_s_key("Parent") + .with_custom_val(|v| write_expression_parent(v, &expr.parent)); }, Expression::Array(expr) => { self.kv(indent).with_id(PREFIX_ARRAY_EXPR_ID, expr.this.0.index) @@ -556,6 +570,9 @@ impl ASTWriter { for expr_id in &expr.elements { self.write_expr(heap, *expr_id, indent3); } + + self.kv(indent2).with_s_key("Parent") + .with_custom_val(|v| 
write_expression_parent(v, &expr.parent)); }, Expression::Constant(expr) => { self.kv(indent).with_id(PREFIX_CONST_EXPR_ID, expr.this.0.index) @@ -569,6 +586,9 @@ impl ASTWriter { Constant::Character(char) => { val.with_ascii_val(char); }, Constant::Integer(int) => { val.with_disp_val(int); }, } + + self.kv(indent2).with_s_key("Parent") + .with_custom_val(|v| write_expression_parent(v, &expr.parent)); }, Expression::Call(expr) => { self.kv(indent).with_id(PREFIX_CALL_EXPR_ID, expr.this.0.index) @@ -593,6 +613,10 @@ impl ASTWriter { for arg_id in &expr.arguments { self.write_expr(heap, *arg_id, indent3); } + + // Parent + self.kv(indent2).with_s_key("Parent") + .with_custom_val(|v| write_expression_parent(v, &expr.parent)); }, Expression::Variable(expr) => { self.kv(indent).with_id(PREFIX_VARIABLE_EXPR_ID, expr.this.0.index) @@ -600,6 +624,8 @@ impl ASTWriter { self.kv(indent2).with_s_key("Name").with_ascii_val(&expr.identifier.value); self.kv(indent2).with_s_key("Definition") .with_opt_disp_val(expr.declaration.as_ref().map(|v| &v.index)); + self.kv(indent2).with_s_key("Parent") + .with_custom_val(|v| write_expression_parent(v, &expr.parent)); } } } @@ -669,4 +695,20 @@ fn write_type(target: &mut String, heap: &Heap, t: &ParserType) { } target.write_str(">"); } +} + +fn write_expression_parent(target: &mut String, parent: &ExpressionParent) { + use ExpressionParent as EP; + + *target = match parent { + EP::None => String::from("None"), + EP::Memory(id) => format!("MemoryStmt({})", id.0.0.index), + EP::If(id) => format!("IfStmt({})", id.0.index), + EP::While(id) => format!("WhileStmt({})", id.0.index), + EP::Return(id) => format!("ReturnStmt({})", id.0.index), + EP::Assert(id) => format!("AssertStmt({})", id.0.index), + EP::Put(id, idx) => format!("PutStmt({}, {})", id.0.index, idx), + EP::ExpressionStmt(id) => format!("ExprStmt({})", id.0.index), + EP::Expression(id, idx) => format!("Expr({}, {})", id.index, idx) + }; } \ No newline at end of file diff --git 
a/src/protocol/lexer.rs b/src/protocol/lexer.rs index 5a87f118da17ea7f45d6a5fb4169057a966b923c..d8b00c04276d90cede9bb56f99a6945cb5db1872 100644 --- a/src/protocol/lexer.rs +++ b/src/protocol/lexer.rs @@ -743,6 +743,7 @@ impl Lexer<'_> { left, operation, right, + parent: ExpressionParent::None, }) .upcast()) } else { @@ -819,6 +820,7 @@ impl Lexer<'_> { test, true_expression, false_expression, + parent: ExpressionParent::None, }) .upcast()) } else { @@ -843,6 +845,7 @@ impl Lexer<'_> { left, operation, right, + parent: ExpressionParent::None, }) .upcast(); } @@ -866,6 +869,7 @@ impl Lexer<'_> { left, operation, right, + parent: ExpressionParent::None, }) .upcast(); } @@ -889,6 +893,7 @@ impl Lexer<'_> { left, operation, right, + parent: ExpressionParent::None, }) .upcast(); } @@ -912,6 +917,7 @@ impl Lexer<'_> { left, operation, right, + parent: ExpressionParent::None, }) .upcast(); } @@ -935,6 +941,7 @@ impl Lexer<'_> { left, operation, right, + parent: ExpressionParent::None, }) .upcast(); } @@ -958,6 +965,7 @@ impl Lexer<'_> { left, operation, right, + parent: ExpressionParent::None, }) .upcast(); } @@ -987,6 +995,7 @@ impl Lexer<'_> { left, operation, right, + parent: ExpressionParent::None, }) .upcast(); } @@ -1026,6 +1035,7 @@ impl Lexer<'_> { left, operation, right, + parent: ExpressionParent::None, }) .upcast(); } @@ -1057,6 +1067,7 @@ impl Lexer<'_> { left, operation, right, + parent: ExpressionParent::None, }) .upcast(); } @@ -1088,6 +1099,7 @@ impl Lexer<'_> { left, operation, right, + parent: ExpressionParent::None, }) .upcast(); } @@ -1123,6 +1135,7 @@ impl Lexer<'_> { left, operation, right, + parent: ExpressionParent::None, }) .upcast(); } @@ -1173,6 +1186,7 @@ impl Lexer<'_> { position, operation, expression, + parent: ExpressionParent::None, }) .upcast()); } @@ -1198,6 +1212,7 @@ impl Lexer<'_> { position, operation, expression, + parent: ExpressionParent::None, }) .upcast(); } else if self.has_string(b"--") { @@ -1211,6 +1226,7 @@ impl Lexer<'_> { 
position, operation, expression, + parent: ExpressionParent::None, }) .upcast(); } else if self.has_string(b"[") { @@ -1236,6 +1252,7 @@ impl Lexer<'_> { subject, from_index: index, to_index, + parent: ExpressionParent::None, }) .upcast(); } else { @@ -1245,6 +1262,7 @@ impl Lexer<'_> { position, subject, index, + parent: ExpressionParent::None, }) .upcast(); } @@ -1268,6 +1286,7 @@ impl Lexer<'_> { position, subject, field, + parent: ExpressionParent::None, }) .upcast(); } @@ -1310,7 +1329,12 @@ impl Lexer<'_> { } } self.consume_string(b"}")?; - Ok(h.alloc_array_expression(|this| ArrayExpression { this, position, elements })) + Ok(h.alloc_array_expression(|this| ArrayExpression { + this, + position, + elements, + parent: ExpressionParent::None, + })) } fn has_constant(&self) -> bool { is_constant(self.source.next()) @@ -1351,7 +1375,12 @@ impl Lexer<'_> { value = Constant::Integer(self.consume_integer()?); } - Ok(h.alloc_constant_expression(|this| ConstantExpression { this, position, value })) + Ok(h.alloc_constant_expression(|this| ConstantExpression { + this, + position, + value, + parent: ExpressionParent::None, + })) } fn has_call_expression(&mut self) -> bool { /* We prevent ambiguity with variables, by looking ahead @@ -1413,7 +1442,8 @@ impl Lexer<'_> { this, position, method, - arguments + arguments, + parent: ExpressionParent::None, })) } fn consume_variable_expression( @@ -1427,6 +1457,7 @@ impl Lexer<'_> { position, identifier, declaration: None, + parent: ExpressionParent::None, })) } diff --git a/src/protocol/parser/type_resolver.rs b/src/protocol/parser/type_resolver.rs index 78f464e91809554092bfefc13e2a24bd77614ee5..c5eb2f784a1e9e87c18869336207ee45d5781c87 100644 --- a/src/protocol/parser/type_resolver.rs +++ b/src/protocol/parser/type_resolver.rs @@ -1,4 +1,5 @@ use crate::protocol::ast::*; +use super::type_table::{ConcreteType, ConcreteTypeVariant}; use super::visitor::{ STMT_BUFFER_INIT_CAPACITY, EXPR_BUFFER_INIT_CAPACITY, @@ -6,6 +7,7 @@ use 
super::visitor::{ Visitor2, VisitorResult }; +use std::collections::HashMap; enum ExprType { Regular, // expression statement or return statement @@ -14,6 +16,17 @@ enum ExprType { Assert, // assert statement } +// TODO: @cleanup I will do a very dirty implementation first, because I have no idea +// what I am doing. +// Very rough idea: +// - go through entire AST first, find all places where we have inferred types +// (which may be embedded) and store them in some kind of map. +// - go through entire AST and visit all expressions depth-first. We will +// attempt to resolve the return type of each expression. If we can't then +// we store them in another lookup map and link the dependency on an +// inferred variable to that expression. +// - keep iterating until we have completely resolved all variables. + /// This particular visitor will recurse depth-first into the AST and ensures /// that all expressions have the appropriate types. At the moment this implies: /// @@ -30,6 +43,10 @@ pub(crate) struct TypeResolvingVisitor { // Buffers for iteration over substatements and subexpressions stmt_buffer: Vec, expr_buffer: Vec, + + // Map for associating "auto" variable with a concrete type where it is not + // yet determined. 
+ env: HashMap } impl TypeResolvingVisitor { @@ -38,6 +55,7 @@ impl TypeResolvingVisitor { expr_type: ExprType::Regular, stmt_buffer: Vec::with_capacity(STMT_BUFFER_INIT_CAPACITY), expr_buffer: Vec::with_capacity(EXPR_BUFFER_INIT_CAPACITY), + env: HashMap::new(), } } @@ -56,6 +74,7 @@ impl Visitor2 for TypeResolvingVisitor { fn visit_function_definition(&mut self, ctx: &mut Ctx, id: FunctionId) -> VisitorResult { let body_stmt_id = ctx.heap[id].body; + self.visit_stmt(ctx, body_stmt_id) } @@ -83,26 +102,16 @@ impl Visitor2 for TypeResolvingVisitor { fn visit_local_memory_stmt(&mut self, ctx: &mut Ctx, id: MemoryStatementId) -> VisitorResult { let memory_stmt = &ctx.heap[id]; - // Type of local should match the type of the initial expression - // For now, with all variables having an explicit type, it seems we - // do not need to consider all expressions within a single definition in - // order to do typechecking and type inference for numeric constants. - - // Since each expression will get an assigned type, and everything - // already has a type, we may traverse leaf-to-root, while assigning - // output types. We throw an error if the types do not match. If an - // expression's type is already assigned then they should match. - - // Note that for numerical types the information may also travel upward. - // That is: if we have: - // - // u32 a = 5 * 2 << 3 + 8 - // - // Then we may infer that the expression yields a u32 type. As a result - // All of the literals 5, 2, 3 and 8 will have type u32 as well. 
+ Ok(()) + } + fn visit_local_channel_stmt(&mut self, ctx: &mut Ctx, id: ChannelStatementId) -> VisitorResult { Ok(()) } +} + +impl TypeResolvingVisitor { + } \ No newline at end of file diff --git a/src/protocol/parser/type_table.rs b/src/protocol/parser/type_table.rs index 4857ae973485b8e07f9d2c75a710e25b62aa5e5e..1941f8da2361f7f1c25b4fea9d22a1def81a13d4 100644 --- a/src/protocol/parser/type_table.rs +++ b/src/protocol/parser/type_table.rs @@ -228,109 +228,29 @@ impl TypeIterator { } } -// #[derive(PartialEq, Eq)] -// enum SpecifiedTypeVariant { -// // No subtypes -// Message, -// Bool, -// Byte, -// Short, -// Int, -// Long, -// String, -// // Always one subtype -// ArrayOf, -// InputOf, -// OutputOf, -// // Variable number of subtypes, depending on the polymorphic arguments on -// // the definition -// InstanceOf(DefinitionId, usize) -// } - -// #[derive(Eq)] -// struct SpecifiedType { -// /// Definition ID, may not be enough as the type may be polymorphic -// definition: DefinitionId, -// /// The polymorphic types for the definition. These are encoded in a list, -// /// which we interpret as the depth-first serialization of the type tree. 
-// poly_vars: Vec -// } -// -// impl PartialEq for SpecifiedType { -// fn eq(&self, other: &Self) -> bool { -// // Should point to same definition and have the same polyvars -// if self.definition.index != other.definition.index { return false; } -// if self.poly_vars.len() != other.poly_vars.len() { return false; } -// for (my_var, other_var) in self.poly_vars.iter().zip(other.poly_vars.iter()) { -// if my_var != other_var { return false; } -// } -// -// return true -// } -// } -// -// impl SpecifiedType { -// fn new_non_polymorph(definition: DefinitionId) -> Self { -// Self{ definition, poly_vars: Vec::new() } -// } -// -// fn new_polymorph(definition: DefinitionId, heap: &Heap, parser_type_id: ParserTypeId) -> Self { -// // Serialize into concrete types -// let mut poly_vars = Vec::new(); -// Self::construct_poly_vars(&mut poly_vars, heap, parser_type_id); -// Self{ definition, poly_vars } -// } -// -// fn construct_poly_vars(poly_vars: &mut Vec, heap: &Heap, parser_type_id: ParserTypeId) { -// // Depth-first construction of poly vars -// let parser_type = &heap[parser_type_id]; -// match &parser_type.variant { -// ParserTypeVariant::Message => { poly_vars.push(SpecifiedTypeVariant::Message); }, -// ParserTypeVariant::Bool => { poly_vars.push(SpecifiedTypeVariant::Bool); }, -// ParserTypeVariant::Byte => { poly_vars.push(SpecifiedTypeVariant::Byte); }, -// ParserTypeVariant::Short => { poly_vars.push(SpecifiedTypeVariant::Short); }, -// ParserTypeVariant::Int => { poly_vars.push(SpecifiedTypeVariant::Int); }, -// ParserTypeVariant::Long => { poly_vars.push(SpecifiedTypeVariant::Long); }, -// ParserTypeVariant::String => { poly_vars.push(SpecifiedTypeVariant::String); }, -// ParserTypeVariant::Array(subtype_id) => { -// poly_vars.push(SpecifiedTypeVariant::ArrayOf); -// Self::construct_poly_vars(poly_vars, heap, *subtype_id); -// }, -// ParserTypeVariant::Input(subtype_id) => { -// poly_vars.push(SpecifiedTypeVariant::InputOf); -// 
Self::construct_poly_vars(poly_vars, heap, *subtype_id); -// }, -// ParserTypeVariant::Output(subtype_id) => { -// poly_vars.push(SpecifiedTypeVariant::OutputOf); -// Self::construct_poly_vars(poly_vars, heap, *subtype_id); -// }, -// ParserTypeVariant::Symbolic(symbolic) => { -// let definition_id = match symbolic.variant { -// SymbolicParserTypeVariant::Definition(definition_id) => definition_id, -// SymbolicParserTypeVariant::PolyArg(_) => { -// // When construct entries in the type table, we no longer allow the -// // unspecified types in the AST, we expect them to be fully inferred. -// debug_assert!(false, "Encountered 'PolyArg' symbolic type. Expected fully inferred types"); -// unreachable!(); -// } -// }; -// -// poly_vars.push(SpecifiedTypeVariant::InstanceOf(definition_id, symbolic.poly_args.len())); -// for subtype_id in &symbolic.poly_args { -// Self::construct_poly_vars(poly_vars, heap, *subtype_id); -// } -// }, -// ParserTypeVariant::IntegerLiteral => { -// debug_assert!(false, "Encountered 'IntegerLiteral' symbolic type. Expected fully inferred types"); -// unreachable!(); -// }, -// ParserTypeVariant::Inferred => { -// debug_assert!(false, "Encountered 'Inferred' symbolic type. Expected fully inferred types"); -// unreachable!(); -// } -// } -// } -// } +pub(crate) enum ConcreteTypeVariant { + // No subtypes + Message, + Bool, + Byte, + Short, + Int, + Long, + String, + // One subtype + Array, + Slice, + Input, + Output, + // Multiple subtypes (definition of thing and number of poly args) + Instance(DefinitionId, usize) +} + +pub(crate) struct ConcreteType { + // serialized version (interpret as serialized depth-first tree, with + // variant indicating the number of children (subtypes)) + pub(crate) v: Vec +} /// Result from attempting to resolve a `ParserType` using the symbol table and /// the type table. 
diff --git a/src/protocol/parser/visitor_linker.rs b/src/protocol/parser/visitor_linker.rs index a3e460dfe95b2352b25dda07250c9618f1471ee5..296273f01537986e6cc0530bcb579177938f1581 100644 --- a/src/protocol/parser/visitor_linker.rs +++ b/src/protocol/parser/visitor_linker.rs @@ -1,3 +1,5 @@ +use std::mem::{replace, swap}; + use crate::protocol::ast::*; use crate::protocol::inputsource::*; use crate::protocol::parser::{symbol_table::*, type_table::*}; @@ -9,6 +11,7 @@ use super::visitor::{ Visitor2, VisitorResult }; +use crate::protocol::ast::ExpressionParent::ExpressionStmt; #[derive(PartialEq, Eq)] enum DefinitionType { @@ -51,6 +54,9 @@ pub(crate) struct ValidityAndLinkerVisitor { cur_scope: Option, def_type: DefinitionType, performing_breadth_pass: bool, + // Parent expression (the previous stmt/expression we visited that could be + // used as an expression parent) + expr_parent: ExpressionParent, // Keeping track of relative position in block in the breadth-first pass. // May not correspond to block.statement[index] if any statements are // inserted after the breadth-pass @@ -72,6 +78,7 @@ impl ValidityAndLinkerVisitor { in_sync: None, in_while: None, cur_scope: None, + expr_parent: ExpressionParent::None, def_type: DefinitionType::Primitive, performing_breadth_pass: false, relative_pos_in_block: 0, @@ -85,10 +92,12 @@ impl ValidityAndLinkerVisitor { self.in_sync = None; self.in_while = None; self.cur_scope = None; + self.expr_parent = ExpressionParent::None; self.def_type = DefinitionType::Primitive; self.relative_pos_in_block = 0; self.performing_breadth_pass = false; self.statement_buffer.clear(); + self.expression_buffer.clear(); self.insert_buffer.clear(); } } @@ -106,6 +115,7 @@ impl Visitor2 for ValidityAndLinkerVisitor { ComponentVariant::Composite => DefinitionType::Composite, }; self.cur_scope = Some(Scope::Definition(id.upcast())); + self.expr_parent = ExpressionParent::None; let body_id = ctx.heap[id].body; self.performing_breadth_pass = true; @@ 
-120,8 +130,9 @@ impl Visitor2 for ValidityAndLinkerVisitor { // Set internal statement indices self.def_type = DefinitionType::Function; self.cur_scope = Some(Scope::Definition(id.upcast())); - let body_id = ctx.heap[id].body; + self.expr_parent = ExpressionParent::None; + let body_id = ctx.heap[id].body; self.performing_breadth_pass = true; self.visit_stmt(ctx, body_id)?; self.performing_breadth_pass = false; @@ -141,7 +152,10 @@ impl Visitor2 for ValidityAndLinkerVisitor { let variable_id = ctx.heap[id].variable; self.checked_local_add(ctx, self.relative_pos_in_block, variable_id)?; } else { + debug_assert_eq!(self.expr_parent, ExpressionParent::None); + self.expr_parent = ExpressionParent::Memory(id); self.visit_expr(ctx, ctx.heap[id].initial)?; + self.expr_parent = ExpressionParent::None; } Ok(()) @@ -197,7 +211,12 @@ impl Visitor2 for ValidityAndLinkerVisitor { let stmt = &ctx.heap[id]; (stmt.test, stmt.true_body, stmt.false_body) }; + + debug_assert_eq!(self.expr_parent, ExpressionParent::None); + self.expr_parent = ExpressionParent::If(id); self.visit_expr(ctx, test_id)?; + self.expr_parent = ExpressionParent::None; + self.visit_stmt(ctx, true_id)?; self.visit_stmt(ctx, false_id)?; } @@ -227,7 +246,11 @@ impl Visitor2 for ValidityAndLinkerVisitor { (stmt.test, stmt.body) }; let old_while = self.in_while.replace(id); + debug_assert_eq!(self.expr_parent, ExpressionParent::None); + self.expr_parent = ExpressionParent::While(id); self.visit_expr(ctx, test_id)?; + self.expr_parent = ExpressionParent::None; + self.visit_stmt(ctx, body_id)?; self.in_while = old_while; } @@ -318,7 +341,10 @@ impl Visitor2 for ValidityAndLinkerVisitor { } } else { // If here then we are within a function + debug_assert_eq!(self.expr_parent, ExpressionParent::None); + self.expr_parent = ExpressionParent::Return(id); self.visit_expr(ctx, ctx.heap[id].expression)?; + self.expr_parent = ExpressionParent::None; } Ok(()) @@ -345,7 +371,10 @@ impl Visitor2 for ValidityAndLinkerVisitor { ); 
} } else { + debug_assert_eq!(self.expr_parent, ExpressionParent::None); + self.expr_parent = ExpressionParent::Assert(id); let expr_id = stmt.expression; self.visit_expr(ctx, expr_id)?; + self.expr_parent = ExpressionParent::None; } @@ -466,8 +495,13 @@ impl Visitor2 for ValidityAndLinkerVisitor { let put_stmt = &ctx.heap[id]; let port = put_stmt.port; let message = put_stmt.message; + + debug_assert_eq!(self.expr_parent, ExpressionParent::None); + self.expr_parent = ExpressionParent::Put(id, 0); self.visit_expr(ctx, port)?; + self.expr_parent = ExpressionParent::Put(id, 1); self.visit_expr(ctx, message)?; + self.expr_parent = ExpressionParent::None; } Ok(()) @@ -476,7 +510,11 @@ impl Visitor2 for ValidityAndLinkerVisitor { fn visit_expr_stmt(&mut self, ctx: &mut Ctx, id: ExpressionStatementId) -> VisitorResult { if !self.performing_breadth_pass { let expr_id = ctx.heap[id].expression; + + debug_assert_eq!(self.expr_parent, ExpressionParent::None); + self.expr_parent = ExpressionParent::ExpressionStmt(id); self.visit_expr(ctx, expr_id)?; + self.expr_parent = ExpressionParent::None; } Ok(()) @@ -489,23 +527,42 @@ impl Visitor2 for ValidityAndLinkerVisitor { fn visit_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> VisitorResult { debug_assert!(!self.performing_breadth_pass); - let assignment_expr = &ctx.heap[id]; + + let upcast_id = id.upcast(); + let assignment_expr = &mut ctx.heap[id]; + let left_expr_id = assignment_expr.left; let right_expr_id = assignment_expr.right; + let old_expr_parent = self.expr_parent; + assignment_expr.parent = old_expr_parent; + + self.expr_parent = ExpressionParent::Expression(upcast_id, 0); self.visit_expr(ctx, left_expr_id)?; + self.expr_parent = ExpressionParent::Expression(upcast_id, 1); self.visit_expr(ctx, right_expr_id)?; + self.expr_parent = old_expr_parent; Ok(()) } fn visit_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> VisitorResult { 
debug_assert!(!self.performing_breadth_pass); - let conditional_expr = &ctx.heap[id]; + let upcast_id = id.upcast(); + let conditional_expr = &mut ctx.heap[id]; + let test_expr_id = conditional_expr.test; let true_expr_id = conditional_expr.true_expression; let false_expr_id = conditional_expr.false_expression; + let old_expr_parent = self.expr_parent; + conditional_expr.parent = old_expr_parent; + + self.expr_parent = ExpressionParent::Expression(upcast_id, 0); self.visit_expr(ctx, test_expr_id)?; + self.expr_parent = ExpressionParent::Expression(upcast_id, 1); self.visit_expr(ctx, true_expr_id)?; + self.expr_parent = ExpressionParent::Expression(upcast_id, 2); self.visit_expr(ctx, false_expr_id)?; + self.expr_parent = old_expr_parent; + Ok(()) }