diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index 6b2eaef1f425c59ef163f966967fb4a41daff97b..26c59a07ff48d4f95eb99cf4139b2ff262ed38cb 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -737,6 +737,22 @@ pub struct Scope { pub next_unique_id_in_scope: i32, } +impl Scope { + pub(crate) fn new(id: ScopeId, association: ScopeAssociation) -> Self { + return Self{ + this: id, + parent: None, + nested: Vec::new(), + association, + variables: Vec::new(), + labels: Vec::new(), + relative_pos_in_parent: -1, + first_unique_id_in_scope: -1, + next_unique_id_in_scope: -1, + } + } +} + impl Scope { pub(crate) fn new_invalid(this: ScopeId) -> Self { return Self{ diff --git a/src/protocol/parser/pass_definitions.rs b/src/protocol/parser/pass_definitions.rs index 3b0d9abe5004eb827700c49d090cd32e560feba1..383d76f7d9ac00efcd9ea11af1bff108e3bcb3cb 100644 --- a/src/protocol/parser/pass_definitions.rs +++ b/src/protocol/parser/pass_definitions.rs @@ -339,29 +339,9 @@ impl PassDefinitions { // Consume if statement and place end-if statement directly // after it. let id = self.consume_if_statement(module, iter, ctx)?; - - let end_if = ctx.heap.alloc_end_if_statement(|this| EndIfStatement { - this, - start_if: id, - next: StatementId::new_invalid() - }); - - let if_stmt = &mut ctx.heap[id]; - if_stmt.end_if = end_if; - return Ok(id.upcast()); } else if ident == KW_STMT_WHILE { let id = self.consume_while_statement(module, iter, ctx)?; - - let end_while = ctx.heap.alloc_end_while_statement(|this| EndWhileStatement { - this, - start_while: id, - next: StatementId::new_invalid() - }); - - let while_stmt = &mut ctx.heap[id]; - while_stmt.end_while = end_while; - return Ok(id.upcast()); } else if ident == KW_STMT_BREAK { let id = self.consume_break_statement(module, iter, ctx)?; @@ -371,16 +351,6 @@ impl PassDefinitions { return Ok(id.upcast()); } else if ident == KW_STMT_SYNC { let id = self.consume_synchronous_statement(module, iter, ctx)?; - - let end_sync = ctx.heap.alloc_end_synchronous_statement(|this| EndSynchronousStatement { - this, - start_sync: id, - next: StatementId::new_invalid() - }); - - let sync_stmt = &mut ctx.heap[id]; - sync_stmt.end_sync = end_sync; - return Ok(id.upcast()); } else if ident == KW_STMT_FORK { let id = self.consume_fork_statement(module, iter, ctx)?; @@ -397,16 +367,6 @@ impl PassDefinitions { return Ok(id.upcast()); } else if ident == KW_STMT_SELECT { let id = self.consume_select_statement(module, iter, ctx)?; - - let end_select = ctx.heap.alloc_end_select_statement(|this| EndSelectStatement{ - this, - start_select: id, - next: StatementId::new_invalid(), - }); - - let select_stmt = &mut ctx.heap[id]; - select_stmt.end_select = end_select; - return Ok(id.upcast()); } else if ident == KW_STMT_RETURN { let id = self.consume_return_statement(module, iter, ctx)?; @@ -472,7 +432,7 @@ impl PassDefinitions { let mut block_span = consume_token(&module.source, iter, TokenKind::CloseCurly)?; block_span.begin = open_curly_span.begin; - let id = ctx.heap.alloc_block_statement(|this| BlockStatement{ + let block_id = ctx.heap.alloc_block_statement(|this| BlockStatement{ this, is_implicit: false, span: block_span, @@ -481,15 +441,17 @@ impl PassDefinitions { scope: ScopeId::new_invalid(), next: StatementId::new_invalid(), }); + let scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::Block(block_id))); - let end_block = ctx.heap.alloc_end_block_statement(|this| EndBlockStatement{ - this, start_block: id, next: StatementId::new_invalid() + let end_block_id = ctx.heap.alloc_end_block_statement(|this| EndBlockStatement{ + this, start_block: block_id, next: StatementId::new_invalid() }); - let block_stmt = &mut ctx.heap[id]; - block_stmt.end_block = end_block; + let block_stmt = &mut ctx.heap[block_id]; + block_stmt.end_block = end_block_id; + block_stmt.scope = scope_id; - Ok(id) + Ok(block_id) } fn consume_if_statement( @@ -500,11 +462,11 @@ impl PassDefinitions { let test = self.consume_expression(module, iter, ctx)?; consume_token(&module.source, iter, TokenKind::CloseParen)?; + // Consume bodies of if-statement let true_body = IfStatementCase{ body: self.consume_statement(module, iter, ctx)?, scope: ScopeId::new_invalid(), }; - let true_body_scope_id = true_body.scope; let (false_body, false_body_scope_id) = if has_ident(&module.source, iter, KW_STMT_ELSE) { iter.consume(); @@ -519,6 +481,7 @@ impl PassDefinitions { (None, None) }; + // Construct AST elements let if_stmt_id = ctx.heap.alloc_if_statement(|this| IfStatement{ this, span: if_span, @@ -527,6 +490,20 @@ impl PassDefinitions { false_case: false_body, end_if: EndIfStatementId::new_invalid(), }); + let end_if_stmt_id = ctx.heap.alloc_end_if_statement(|this| EndIfStatement{ + this, + start_if: if_stmt_id, + next: StatementId::new_invalid(), + }); + let true_scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::If(if_stmt_id, true))); + + let if_stmt = &mut ctx.heap[if_stmt_id]; + if_stmt.end_if = end_if_stmt_id; + if_stmt.true_case.scope = true_scope_id; + if let Some(false_case) = &mut if_stmt.false_case { + let false_scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::If(if_stmt_id, false))); + false_case.scope = false_scope_id + } return Ok(if_stmt_id); } @@ -540,7 +517,7 @@ impl PassDefinitions { consume_token(&module.source, iter, TokenKind::CloseParen)?; let body = self.consume_statement(module, iter, ctx)?; - Ok(ctx.heap.alloc_while_statement(|this| WhileStatement{ + let while_stmt_id = ctx.heap.alloc_while_statement(|this| WhileStatement{ this, span: while_span, test, @@ -548,7 +525,19 @@ impl PassDefinitions { body, end_while: EndWhileStatementId::new_invalid(), in_sync: SynchronousStatementId::new_invalid(), - })) + }); + let end_while_stmt_id = ctx.heap.alloc_end_while_statement(|this| EndWhileStatement{ + this, + start_while: while_stmt_id, + next: StatementId::new_invalid(), + }); + let scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::While(while_stmt_id))); + + let while_stmt = &mut ctx.heap[while_stmt_id]; + while_stmt.scope = scope_id; + while_stmt.end_while = end_while_stmt_id; + + Ok(while_stmt_id) } fn consume_break_statement( @@ -595,13 +584,25 @@ impl PassDefinitions { let synchronous_span = consume_exact_ident(&module.source, iter, KW_STMT_SYNC)?; let body = self.consume_statement(module, iter, ctx)?; - Ok(ctx.heap.alloc_synchronous_statement(|this| SynchronousStatement{ + let sync_stmt_id = ctx.heap.alloc_synchronous_statement(|this| SynchronousStatement{ this, span: synchronous_span, scope: ScopeId::new_invalid(), body, end_sync: EndSynchronousStatementId::new_invalid(), - })) + }); + let end_sync_stmt_id = ctx.heap.alloc_end_synchronous_statement(|this| EndSynchronousStatement{ + this, + start_sync: sync_stmt_id, + next: StatementId::new_invalid(), + }); + let scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::Synchronous(sync_stmt_id))); + + let sync_stmt = &mut ctx.heap[sync_stmt_id]; + sync_stmt.scope = scope_id; + sync_stmt.end_sync = end_sync_stmt_id; + + return Ok(sync_stmt_id); } fn consume_fork_statement( @@ -668,12 +669,30 @@ impl PassDefinitions { consume_token(&module.source, iter, TokenKind::CloseCurly)?; - Ok(ctx.heap.alloc_select_statement(|this| SelectStatement{ + let num_cases = cases.len(); + let select_stmt_id = ctx.heap.alloc_select_statement(|this| SelectStatement{ this, span: select_span, cases, end_select: EndSelectStatementId::new_invalid(), - })) + }); + let end_select_stmt_id = ctx.heap.alloc_end_select_statement(|this| EndSelectStatement{ + this, + start_select: select_stmt_id, + next: StatementId::new_invalid(), + }); + + let select_stmt = &mut ctx.heap[select_stmt_id]; + select_stmt.end_select = end_select_stmt_id; + + for case_index in 0..num_cases { + let scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::SelectCase(select_stmt_id, case_index as u32))); + let select_stmt = &mut ctx.heap[select_stmt_id]; + let select_case = &mut select_stmt.cases[case_index]; + select_case.scope = scope_id; + } + + return Ok(select_stmt_id) } fn consume_return_statement( diff --git a/src/protocol/parser/pass_validation_linking.rs b/src/protocol/parser/pass_validation_linking.rs index 954bcc968214684eadc3f67b972d00aa463b067a..50501f7f680a8efd37dafad58ec0a94976dc3a29 100644 --- a/src/protocol/parser/pass_validation_linking.rs +++ b/src/protocol/parser/pass_validation_linking.rs @@ -209,14 +209,16 @@ impl Visitor for PassValidationLinking { fn visit_component_definition(&mut self, ctx: &mut Ctx, id: ComponentDefinitionId) -> VisitorResult { self.reset_state(); - self.def_type = match &ctx.heap[id].variant { + let definition = &ctx.heap[id]; + self.def_type = match &definition.variant { ComponentVariant::Primitive => DefinitionType::Primitive(id), ComponentVariant::Composite => DefinitionType::Composite(id), }; self.expr_parent = ExpressionParent::None; // Visit parameters and assign a unique scope ID - let old_scope = self.push_scope(ctx, ScopeAssociation::Definition(id.upcast())); + let definition_scope_id = definition.scope; + let old_scope = self.push_scope(ctx, true, definition_scope_id); let definition = &ctx.heap[id]; let body_id = definition.body; @@ -253,7 +255,9 @@ impl Visitor for PassValidationLinking { self.expr_parent = ExpressionParent::None; // Visit parameters and assign a unique scope ID - let old_scope = self.push_scope(ctx, ScopeAssociation::Definition(id.upcast())); + let definition = &ctx.heap[id]; + let definition_scope_id = definition.scope; + let old_scope = self.push_scope(ctx, true, definition_scope_id); let definition = &ctx.heap[id]; let body_id = definition.body; @@ -287,10 +291,11 @@ impl Visitor for PassValidationLinking { // Get end of block let block_stmt = &ctx.heap[id]; let end_block_id = block_stmt.end_block; + let scope_id = block_stmt.scope; // Traverse statements in block let statement_section = self.statement_buffer.start_section_initialized(&block_stmt.statements); - let old_scope = self.push_scope(ctx, ScopeAssociation::Block(id)); + let old_scope = self.push_scope(ctx, false, scope_id); assign_and_replace_next_stmt!(self, ctx, id.upcast()); for stmt_idx in 0..statement_section.len() { @@ -365,13 +370,13 @@ impl Visitor for PassValidationLinking { // test expression, not on if-statement itself. Hence the if statement // does not have a static subsequent statement. assign_then_erase_next_stmt!(self, ctx, id.upcast()); - let old_scope = self.push_scope(ctx, ScopeAssociation::If(id, true)); + let old_scope = self.push_scope(ctx, false, true_case.scope); self.visit_stmt(ctx, true_case.body)?; self.pop_scope(old_scope); assign_then_erase_next_stmt!(self, ctx, end_if_id.upcast()); if let Some(false_case) = false_case { - let old_scope = self.push_scope(ctx, ScopeAssociation::If(id, false)); + let old_scope = self.push_scope(ctx, false, false_case.scope); self.visit_stmt(ctx, false_case.body)?; self.pop_scope(old_scope); assign_then_erase_next_stmt!(self, ctx, end_if_id.upcast()); @@ -386,6 +391,7 @@ impl Visitor for PassValidationLinking { let end_while_id = stmt.end_while; let test_expr_id = stmt.test; let body_stmt_id = stmt.body; + let scope_id = stmt.scope; let old_while = self.in_while; self.in_while = id; @@ -402,7 +408,7 @@ impl Visitor for PassValidationLinking { assign_then_erase_next_stmt!(self, ctx, id.upcast()); self.expr_parent = ExpressionParent::None; - let old_scope = self.push_scope(ctx, ScopeAssociation::While(id)); + let old_scope = self.push_scope(ctx, false, scope_id); self.visit_stmt(ctx, body_stmt_id)?; self.pop_scope(old_scope); self.in_while = old_while; @@ -445,6 +451,8 @@ impl Visitor for PassValidationLinking { let sync_stmt = &ctx.heap[id]; let end_sync_id = sync_stmt.end_sync; let cur_sync_span = sync_stmt.span; + let scope_id = sync_stmt.scope; + if !self.in_sync.is_invalid() { // Nested synchronous statement let old_sync_span = ctx.heap[self.in_sync].span; @@ -471,7 +479,7 @@ impl Visitor for PassValidationLinking { let sync_body = ctx.heap[id].body; debug_assert!(self.in_sync.is_invalid()); self.in_sync = id; - let old_scope = self.push_scope(ctx, ScopeAssociation::Synchronous(id)); + let old_scope = self.push_scope(ctx, false, scope_id); self.visit_stmt(ctx, sync_body)?; self.pop_scope(old_scope); assign_and_replace_next_stmt!(self, ctx, end_sync_id.upcast()); @@ -532,29 +540,32 @@ impl Visitor for PassValidationLinking { // Visit the various arms in the select block let mut case_stmt_ids = self.statement_buffer.start_section(); + let mut case_scope_ids = self.scope_buffer.start_section(); let num_cases = select_stmt.cases.len(); for case in &select_stmt.cases { // We add them in pairs, so the subsequent for-loop retrieves in pairs case_stmt_ids.push(case.guard); case_stmt_ids.push(case.body); + case_scope_ids.push(case.scope); } assign_then_erase_next_stmt!(self, ctx, id.upcast()); - for idx in 0..num_cases { - let base_idx = 2 * idx; - let guard_id = case_stmt_ids[base_idx ]; - let case_body_id = case_stmt_ids[base_idx + 1]; + for index in 0..num_cases { + let base_index = 2 * index; + let guard_id = case_stmt_ids[base_index]; + let case_body_id = case_stmt_ids[base_index + 1]; + let case_scope_id = case_scope_ids[index]; // The guard statement ends up belonging to the block statement // following the arm. The reason we parse it separately is to // extract all of the "get" calls. - let old_scope = self.push_scope(ctx, ScopeAssociation::SelectCase(id, idx as u32)); + let old_scope = self.push_scope(ctx, false, case_scope_id); // Visit the guard of this arm debug_assert!(self.in_select_guard.is_invalid()); self.in_select_guard = id; - self.in_select_arm = idx as u32; + self.in_select_arm = index as u32; self.visit_stmt(ctx, guard_id)?; self.in_select_guard = SelectStatementId::new_invalid(); @@ -1453,71 +1464,29 @@ impl PassValidationLinking { /// sync statement or select statement's arm) then we won't do anything. /// In all cases the caller must call `pop_statement_scope` with the scope /// and relative scope position returned by this function. - fn push_scope(&mut self, ctx: &mut Ctx, association: ScopeAssociation) -> (ScopeId, i32) { - // Create new scope and assign as ScopeId to the specified associated - // statement. - let is_first_scope = match association { - ScopeAssociation::Definition(_) => true, - _ => false, - }; - - let old_scope_id = self.cur_scope.clone(); - let new_scope_id = ctx.heap.alloc_scope(|this| Scope{ - this, - parent: if is_first_scope { None } else { Some(old_scope_id) }, - nested: Vec::new(), - association, - variables: Vec::new(), - labels: Vec::new(), - relative_pos_in_parent: self.relative_pos_in_parent, - first_unique_id_in_scope: -1, - next_unique_id_in_scope: -1, - }); - - match association { - ScopeAssociation::Definition(definition_id) => { - let def = &mut ctx.heap[definition_id]; - match def { - Definition::Function(def) => def.scope = new_scope_id, - Definition::Component(def) => def.scope = new_scope_id, - _ => unreachable!(), - } - }, - ScopeAssociation::Block(stmt_id) => { - ctx.heap[stmt_id].scope = new_scope_id; - }, - ScopeAssociation::If(stmt_id, in_true_body) => { - let stmt = &mut ctx.heap[stmt_id]; - if in_true_body { - stmt.true_case.scope = new_scope_id; - } else { - let false_body = stmt.false_case.as_mut().unwrap(); - false_body.scope = new_scope_id; - } - }, - ScopeAssociation::While(stmt_id) => { - ctx.heap[stmt_id].scope = new_scope_id; - }, - ScopeAssociation::Synchronous(stmt_id) => { - ctx.heap[stmt_id].scope = new_scope_id; - }, - ScopeAssociation::SelectCase(stmt_id, case_index) => { - let select_stmt = &mut ctx.heap[stmt_id]; - select_stmt.cases[case_index as usize].scope = new_scope_id; - } + fn push_scope(&mut self, ctx: &mut Ctx, is_top_level_scope: bool, pushed_scope_id: ScopeId) -> (ScopeId, i32) { + // Set the properties of the pushed scope (it is already created during + // AST construction, but most values are not yet set to their correct + // values) + let old_scope_id = self.cur_scope; + + let scope = &mut ctx.heap[pushed_scope_id]; + if is_top_level_scope { + scope.parent = Some(old_scope_id); } + scope.relative_pos_in_parent = self.relative_pos_in_parent; + let old_relative_pos = self.relative_pos_in_parent; + self.relative_pos_in_parent = 0; + // Link up scopes - if !is_first_scope { + if !is_top_level_scope { let old_scope = &mut ctx.heap[old_scope_id]; - old_scope.nested.push(new_scope_id); + old_scope.nested.push(pushed_scope_id); } // Set as current traversal scope, then return old scope - self.cur_scope = new_scope_id; - - let old_relative_pos = self.relative_pos_in_parent; - self.relative_pos_in_parent = 0; + self.cur_scope = pushed_scope_id; return (old_scope_id, old_relative_pos) }