diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index b92a35cc3ebf01a92e04de55bbc52babd6d429a8..3978a7d30a793bf1c852e84768971e4221425a25 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -1362,6 +1362,8 @@ pub struct SelectCase { // ExpressionStatement. Nothing else is allowed by the initial parsing pub guard: StatementId, pub block: BlockStatementId, + // Phase 2: Validation and Linking + pub involved_ports: Vec<(CallExpressionId, ExpressionId)>, // call to `get` and its port argument } #[derive(Debug, Clone)] diff --git a/src/protocol/parser/pass_definitions.rs b/src/protocol/parser/pass_definitions.rs index 94de9499c7f17bf17a73332551296924f9d7953d..144ae43b5ebb17aec83f692e1f76a28f78fca8e6 100644 --- a/src/protocol/parser/pass_definitions.rs +++ b/src/protocol/parser/pass_definitions.rs @@ -713,7 +713,10 @@ impl PassDefinitions { }; consume_token(&module.source, iter, TokenKind::ArrowRight)?; let block = self.consume_block_or_wrapped_statement(module, iter, ctx)?; - cases.push(SelectCase{ guard, block }); + cases.push(SelectCase{ + guard, block, + involved_ports: Vec::with_capacity(1) + }); next = iter.next(); } diff --git a/src/protocol/parser/pass_typing.rs b/src/protocol/parser/pass_typing.rs index 6bc1fc4c18fd33b15b8772cf7e10ab17fff706ad..3df60af0a20e6220a55019fe3345fa30ca510434 100644 --- a/src/protocol/parser/pass_typing.rs +++ b/src/protocol/parser/pass_typing.rs @@ -1175,6 +1175,30 @@ impl Visitor for PassTyping { Ok(()) } + fn visit_select_stmt(&mut self, ctx: &mut Ctx, id: SelectStatementId) -> VisitorResult { + let select_stmt = &ctx.heap[id]; + + let mut section = self.stmt_buffer.start_section(); + let num_cases = select_stmt.cases.len(); + + for case in &select_stmt.cases { + section.push(case.guard); + section.push(case.block.upcast()); + } + + for case_index in 0..num_cases { + let base_index = 2 * case_index; + let guard_stmt_id = section[base_index ]; + let block_stmt_id = section[base_index + 1]; + + self.visit_stmt(ctx, guard_stmt_id); + self.visit_stmt(ctx, block_stmt_id); + } + section.forget(); + + Ok(()) + } + fn visit_return_stmt(&mut self, ctx: &mut Ctx, id: ReturnStatementId) -> VisitorResult { let return_stmt = &ctx.heap[id]; debug_assert_eq!(return_stmt.expressions.len(), 1); diff --git a/src/protocol/parser/pass_validation_linking.rs b/src/protocol/parser/pass_validation_linking.rs index c299a3b8e364e5e64cafb9cf106ee3f4ce6ecf38..621092216319cdae738794dfffda706abec1aa9b 100644 --- a/src/protocol/parser/pass_validation_linking.rs +++ b/src/protocol/parser/pass_validation_linking.rs @@ -84,6 +84,8 @@ pub(crate) struct PassValidationLinking { // `id.is_invalid()` returns true. in_sync: SynchronousStatementId, in_while: WhileStatementId, // to resolve labeled continue/break + in_select_guard: SelectStatementId, // for detection/rejection of builtin calls + in_select_arm: u32, in_test_expr: StatementId, // wrapping if/while stmt id in_binding_expr: BindingExpressionId, // to resolve variable expressions in_binding_expr_lhs: bool, @@ -115,6 +117,8 @@ impl PassValidationLinking { Self{ in_sync: SynchronousStatementId::new_invalid(), in_while: WhileStatementId::new_invalid(), + in_select_guard: SelectStatementId::new_invalid(), + in_select_arm: 0, in_test_expr: StatementId::new_invalid(), in_binding_expr: BindingExpressionId::new_invalid(), in_binding_expr_lhs: false, @@ -476,18 +480,40 @@ impl Visitor for PassValidationLinking { assign_then_erase_next_stmt!(self, ctx, id.upcast()); // Link up the "Select" with the "EndSelect". If there are no cases then - // runtime will pick the "EndSelect" immediately. + // runtime should pick the "EndSelect" immediately. + for idx in 0..num_cases { let base_idx = 2 * idx; let guard_id = case_stmt_ids[base_idx ]; let arm_block_id = case_stmt_ids[base_idx + 1]; + // Visit the guard of this arm + debug_assert!(self.in_select_guard.is_invalid()); + self.in_select_guard = id; + self.in_select_arm = idx as u32; self.visit_stmt(ctx, guard_id)?; + self.in_select_guard = SelectStatementId::new_invalid(); + + // If the arm declares a variable, then add it to the variables of + // this arm's code block + match &ctx.heap[guard_id] { + Statement::Local(LocalStatement::Memory(stmt)) => { + debug_assert_eq!(ctx.heap[arm_block_id].as_block().this.upcast(), arm_block_id); // backwards way of saying arm_block_id is a BlockStatementId + let block_stmt = BlockStatementId(arm_block_id); + let block_scope = Scope::Regular(block_stmt); + self.checked_at_single_scope_add_local(ctx, block_scope, -1, stmt.variable)?; + }, + Statement::Expression(_) => {}, + _ => unreachable!(), // just to be sure the parser produced the expected AST + } + + // Visit the code associated with the guard self.visit_stmt(ctx, arm_block_id)?; assign_then_erase_next_stmt!(self, ctx, end_select_id.upcast()); } + self.in_select_guard = SelectStatementId::new_invalid(); self.prev_stmt = end_select_id.upcast(); Ok(()) } @@ -1101,7 +1127,7 @@ impl Visitor for PassValidationLinking { } fn visit_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> VisitorResult { - let call_expr = &mut ctx.heap[id]; + let call_expr = &ctx.heap[id]; if let Some(span) = self.must_be_assignable { return Err(ParseError::new_error_str_at_span( @@ -1111,59 +1137,49 @@ impl Visitor for PassValidationLinking { // Check whether the method is allowed to be called within the code's // context (in sync, definition type, etc.) - let mut expected_wrapping_new_stmt = false; - match &mut call_expr.method { + let mut expecting_wrapping_new_stmt = false; + let mut expecting_primitive_def = false; + let mut expecting_wrapping_sync_stmt = false; + let mut expecting_no_select_stmt = false; + + match call_expr.method { Method::Get => { - if !self.def_type.is_primitive() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "a call to 'get' may only occur in primitive component definitions" - )); - } - if self.in_sync.is_invalid() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "a call to 'get' may only occur inside synchronous blocks" - )); + expecting_primitive_def = true; + expecting_wrapping_sync_stmt = true; + if !self.in_select_guard.is_invalid() { + // In a select guard. Take the argument (i.e. the port we're + // retrieving from) and add it to the list of involved ports + // of the guard + if call_expr.arguments.len() == 1 { + // We're checking the number of arguments later, for now + // assume it is correct. + let argument = call_expr.arguments[0]; + let select_stmt = &mut ctx.heap[self.in_select_guard]; + let select_case = &mut select_stmt.cases[self.in_select_arm as usize]; + select_case.involved_ports.push((id, argument)); + } } }, Method::Put => { - if !self.def_type.is_primitive() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "a call to 'put' may only occur in primitive component definitions" - )); - } - if self.in_sync.is_invalid() { + expecting_primitive_def = true; + expecting_wrapping_sync_stmt = true; + if !self.in_select_guard.is_invalid() { let call_span = call_expr.func_span; return Err(ParseError::new_error_str_at_span( &ctx.module().source, call_span, - "a call to 'put' may only occur inside synchronous blocks" + "a call to 'put' may not occur in a select statement's guard" )); } }, Method::Fires => { - if !self.def_type.is_primitive() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "a call to 'fires' may only occur in primitive component definitions" - )); - } - if self.in_sync.is_invalid() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "a call to 'fires' may only occur inside synchronous blocks" - )); - } + expecting_primitive_def = true; + expecting_wrapping_sync_stmt = true; }, Method::Create => {}, Method::Length => {}, Method::Assert => { + expecting_wrapping_sync_stmt = true; + expecting_no_select_stmt = true; if self.def_type.is_function() { let call_span = call_expr.func_span; return Err(ParseError::new_error_str_at_span( @@ -1171,22 +1187,43 @@ impl Visitor for PassValidationLinking { "assert statement may only occur in components" )); } - if self.in_sync.is_invalid() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "assert statements may only occur inside synchronous blocks" - )); - } }, Method::Print => {}, Method::UserFunction => {}, Method::UserComponent => { - expected_wrapping_new_stmt = true; + expecting_wrapping_new_stmt = true; }, } - if expected_wrapping_new_stmt { + let call_expr = &mut ctx.heap[id]; + + fn get_span_and_name<'a>(ctx: &'a Ctx, id: CallExpressionId) -> (InputSpan, String) { + let call = &ctx.heap[id]; + let span = call.func_span; + let name = String::from_utf8_lossy(ctx.module().source.section_at_span(span)).to_string(); + return (span, name); + } + if expecting_primitive_def { + if !self.def_type.is_primitive() { + let (call_span, func_name) = get_span_and_name(ctx, id); + return Err(ParseError::new_error_at_span( + &ctx.module().source, call_span, + format!("a call to '{}' may only occur in primitive component definitions", func_name) + )); + } + } + + if expecting_wrapping_sync_stmt { + if self.in_sync.is_invalid() { + let (call_span, func_name) = get_span_and_name(ctx, id); + return Err(ParseError::new_error_at_span( + &ctx.module().source, call_span, + format!("a call to '{}' may only occur inside synchronous blocks", func_name) + )) + } + } + + if expecting_wrapping_new_stmt { if !self.expr_parent.is_new() { let call_span = call_expr.func_span; return Err(ParseError::new_error_str_at_span( @@ -1630,7 +1667,7 @@ impl PassValidationLinking { for other_local_id in &block.locals { let other_local = &ctx.heap[*other_local_id]; if local.this != other_local.this && - relative_pos >= other_local.relative_pos_in_block && + // relative_pos >= other_local.relative_pos_in_block && local.identifier == other_local.identifier { // Collision return Err(