From 949789bac3cf76b048587761e7dd79afadeab935 2021-12-20 09:53:17 From: mh Date: 2021-12-20 09:53:17 Subject: [PATCH] Parsing of select statement --- diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index a8b18a801eef40f745879bd48630af59c8159dc3..7657a9390ebffa00927a9cdc1bc6ae71225e0df2 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -1357,7 +1357,7 @@ pub struct SelectStatement { #[derive(Debug, Clone)] pub struct SelectCase { - pub guard_var: Option, // optional memory declaration + pub guard_var: MemoryStatementId, // invalid ID if there is no declaration of a variable pub guard_expr: ExpressionStatementId, // if `guard_var.is_some()`, then always assignment expression pub block: BlockStatementId, } diff --git a/src/protocol/ast_printer.rs b/src/protocol/ast_printer.rs index 8a059052ecad78939c48f82b234610db38ebccbd..1075d78961a11fdf554c641ea1d37715f0fc50bb 100644 --- a/src/protocol/ast_printer.rs +++ b/src/protocol/ast_printer.rs @@ -541,9 +541,9 @@ impl ASTWriter { let indent3 = indent2 + 1; let indent4 = indent3 + 1; for case in &stmt.cases { - if let Some(guard_var_id) = case.guard_var { + if !case.guard_var.is_invalid() { self.kv(indent3).with_s_key("GuardStatement"); - self.write_stmt(heap, guard_var_id.upcast().upcast(), indent4); + self.write_stmt(heap, case.guard_var.upcast().upcast(), indent4); } else { self.kv(indent3).with_s_key("GuardStatement").with_s_val("None"); } diff --git a/src/protocol/parser/pass_definitions.rs b/src/protocol/parser/pass_definitions.rs index af7df9289e33d4f14b041012aed4b6d3581b76fe..885c5aa4a83626d35fed98fa687fed9be50e92ff 100644 --- a/src/protocol/parser/pass_definitions.rs +++ b/src/protocol/parser/pass_definitions.rs @@ -487,7 +487,8 @@ impl PassDefinitions { // Two fallback possibilities: the first one is a memory // declaration, the other one is to parse it as a normal // expression. This is a bit ugly. 
- if let Some((memory_stmt_id, assignment_stmt_id)) = self.maybe_consume_memory_statement(module, iter, ctx)? { + if let Some((memory_stmt_id, assignment_stmt_id)) = self.maybe_consume_memory_statement_without_semicolon(module, iter, ctx)? { + consume_token(&module.source, iter, TokenKind::SemiColon)?; section.push(memory_stmt_id.upcast().upcast()); section.push(assignment_stmt_id.upcast()); } else { @@ -497,7 +498,8 @@ impl PassDefinitions { } } else if next == TokenKind::OpenParen { // Same as above: memory statement or normal expression - if let Some((memory_stmt_id, assignment_stmt_id)) = self.maybe_consume_memory_statement(module, iter, ctx)? { + if let Some((memory_stmt_id, assignment_stmt_id)) = self.maybe_consume_memory_statement_without_semicolon(module, iter, ctx)? { + consume_token(&module.source, iter, TokenKind::SemiColon)?; section.push(memory_stmt_id.upcast().upcast()); section.push(assignment_stmt_id.upcast()); } else { @@ -691,28 +693,34 @@ impl PassDefinitions { consume_token(&module.source, iter, TokenKind::OpenCurly)?; let mut cases = Vec::new(); - consume_comma_separated_until( - TokenKind::CloseCurly, &module.source, iter, ctx, - |source, iter, ctx| { - // A select arm starts with a guard, being something of the form - // `defined_var = get(port)`, `get(port)` or - // `Type var = get(port)`. So: - let (guard_var, guard_expr) = match self.maybe_consume_memory_statement(module, iter, ctx)?
{ - Some((guard_var, guard_expr)) => { - (Some(guard_var), guard_expr) - }, - None => { - let guard_expr = self.consume_expression_statement(module, iter, ctx)?; - (None, guard_expr) - }, - }; - consume_token(source, iter, TokenKind::ArrowRight)?; - let block = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + let mut next = iter.next(); - Ok(SelectCase{ guard_var, guard_expr, block }) - }, - &mut cases, "select arm", None - )?; + while Some(TokenKind::CloseCurly) != next { + let (guard_var, guard_expr) = match self.maybe_consume_memory_statement_without_semicolon(module, iter, ctx)? { + Some(guard_var_and_expr) => guard_var_and_expr, + None => { + let start_pos = iter.last_valid_pos(); + let expr = self.consume_expression(module, iter, ctx)?; + let end_pos = iter.last_valid_pos(); + + let guard_expr = ctx.heap.alloc_expression_statement(|this| ExpressionStatement{ + this, + span: InputSpan::from_positions(start_pos, end_pos), + expression: expr, + next: StatementId::new_invalid(), + }); + + (MemoryStatementId::new_invalid(), guard_expr) + }, + }; + consume_token(&module.source, iter, TokenKind::ArrowRight)?; + let block = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + cases.push(SelectCase{ guard_var, guard_expr, block }); + + next = iter.next(); + } + + consume_token(&module.source, iter, TokenKind::CloseCurly)?; Ok(ctx.heap.alloc_select_statement(|this| SelectStatement{ this, @@ -917,7 +925,7 @@ impl PassDefinitions { Ok(()) } - fn maybe_consume_memory_statement( + fn maybe_consume_memory_statement_without_semicolon( &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx ) -> Result, ParseError> { // This is a bit ugly. 
It would be nicer if we could somehow @@ -942,7 +950,6 @@ impl PassDefinitions { let initial_expr_begin_pos = iter.last_valid_pos(); let initial_expr_id = self.consume_expression(module, iter, ctx)?; let initial_expr_end_pos = iter.last_valid_pos(); - consume_token(&module.source, iter, TokenKind::SemiColon)?; // Allocate the memory statement with the variable let local_id = ctx.heap.alloc_variable(|this| Variable{ diff --git a/src/protocol/parser/pass_validation_linking.rs b/src/protocol/parser/pass_validation_linking.rs index df9d4551039e33aced60cb296859e92fd23c9d79..0be50a5a640db03eed0ff2175564a511314ab7fc 100644 --- a/src/protocol/parser/pass_validation_linking.rs +++ b/src/protocol/parser/pass_validation_linking.rs @@ -436,6 +436,59 @@ impl Visitor for PassValidationLinking { Ok(()) } + fn visit_select_stmt(&mut self, ctx: &mut Ctx, id: SelectStatementId) -> VisitorResult { + let select_stmt = &ctx.heap[id]; + let end_select_id = select_stmt.end_select; + + // Select statements may only occur inside sync blocks + if self.in_sync.is_invalid() { + return Err(ParseError::new_error_str_at_span( + &ctx.module().source, select_stmt.span, + "select statements may only occur inside sync blocks" + )); + } + + if !self.def_type.is_primitive() { + return Err(ParseError::new_error_str_at_span( + &ctx.module().source, select_stmt.span, + "select statements may only be used in primitive components" + )); + } + + // Visit the various arms in the select block + // note: three statements per case, so we lookup as `3 * index + offset` + let mut case_stmt_ids = self.statement_buffer.start_section(); + let num_cases = select_stmt.cases.len(); + for case in &select_stmt.cases { + case_stmt_ids.push(case.guard_var.upcast().upcast()); + case_stmt_ids.push(case.guard_expr.upcast()); + case_stmt_ids.push(case.block.upcast()); + } + + assign_then_erase_next_stmt!(self, ctx, id.upcast()); + + // Link up the "Select" with the "EndSelect". 
If there are no cases then + // runtime will pick the "EndSelect" immediately. + for idx in 0..num_cases { + let base_idx = 3 * idx; + let guard_var_id = case_stmt_ids[base_idx ]; + let guard_expr_id = case_stmt_ids[base_idx + 1]; + let arm_block_id = case_stmt_ids[base_idx + 2]; + + if !guard_var_id.is_invalid() { + self.visit_stmt(ctx, guard_var_id)?; + } + + self.visit_stmt(ctx, guard_expr_id)?; + self.visit_stmt(ctx, arm_block_id)?; + + assign_then_erase_next_stmt!(self, ctx, end_select_id.upcast()); + } + + self.prev_stmt = end_select_id.upcast(); + Ok(()) + } + fn visit_return_stmt(&mut self, ctx: &mut Ctx, id: ReturnStatementId) -> VisitorResult { // Check if "return" occurs within a function let stmt = &ctx.heap[id]; @@ -1219,8 +1272,8 @@ impl Visitor for PassValidationLinking { true } Expression::Literal(lit_expr) => { - // Only struct, unions and arrays can have - // subexpressions, so we're always fine + // Only struct, unions, tuples and arrays can + // have subexpressions, so we're always fine if cfg!(debug_assertions) { match lit_expr.value { Literal::Struct(_) | Literal::Union(_) | Literal::Array(_) | Literal::Tuple(_) => {}, diff --git a/src/protocol/tests/mod.rs b/src/protocol/tests/mod.rs index 70eb7af70543577dd8a83fc315e8ce7a29221c1e..ae0f331b905942da4bd4f3c072c5fb71b6d49e1b 100644 --- a/src/protocol/tests/mod.rs +++ b/src/protocol/tests/mod.rs @@ -12,7 +12,7 @@ */ mod utils; -mod lexer; +mod parser_after_tokenizing; mod parser_binding; mod parser_imports; mod parser_inference; diff --git a/src/protocol/tests/lexer.rs b/src/protocol/tests/parser_after_tokenizing.rs similarity index 99% rename from src/protocol/tests/lexer.rs rename to src/protocol/tests/parser_after_tokenizing.rs index 2c7f4ef5d81bdcba5631a065803ca7cd1d10ca7c..c05ffc0589703b6b2593f44be9c58e3c2a6be505 100644 --- a/src/protocol/tests/lexer.rs +++ b/src/protocol/tests/parser_after_tokenizing.rs @@ -1,4 +1,4 @@ -/// lexer.rs +/// parser_after_tokenizing /// /// Simple tests for the 
lexer. Only tests the lexing of the input source and /// the resulting AST without relying on the validation/typing pass diff --git a/src/protocol/tests/parser_validation.rs b/src/protocol/tests/parser_validation.rs index 0462c7e1e7f81070d022ab9fa36ad21f66d6203c..7d516ce640715f9c74acc159ac56a7ed8307bef9 100644 --- a/src/protocol/tests/parser_validation.rs +++ b/src/protocol/tests/parser_validation.rs @@ -528,4 +528,79 @@ fn test_polymorph_array_types() { ).for_struct("Bar", |s| { s .for_field("inputs", |f| { f.assert_parser_type("in[]"); }); }); +} + +#[test] +fn test_correct_modifying_operators() { + // Not testing the types, just that it parses + Tester::new_single_source_expect_ok( + "valid uses", + " + func f() -> u32 { + auto a = 5; + a += 2; a -= 2; a *= 2; a /= 2; a %= 2; + a <<= 2; a >>= 2; + a |= 2; a &= 2; a ^= 2; + return a; + } + " + ); +} + +#[test] +fn test_incorrect_modifying_operators() { + Tester::new_single_source_expect_err( + "wrong declaration", + "func f() -> u8 { auto a += 2; return a; }" + ).error(|e| { e.assert_msg_has(0, "expected '='"); }); + + Tester::new_single_source_expect_err( + "inside function", + "func f(u32 a) -> u32 { auto b = 0; auto c = f(a += 2); }" + ).error(|e| { e.assert_msg_has(0, "assignments are statements"); }); + + Tester::new_single_source_expect_err( + "inside tuple", + "func f(u32 a) -> u32 { auto b = (a += 2, a /= 2); return 0; }" + ).error(|e| { e.assert_msg_has(0, "assignments are statements"); }); +} + +#[test] +fn test_correct_select_statement() { + Tester::new_single_source_expect_ok( + "correct single-use", " + primitive f() { + channel unused_output -> input; + u32 outer_value = 0; + sync select { + outer_value = get(input) -> outer_value = 0; + auto new_value = get(input) -> { + f(); + outer_value = new_value; + } + get(input) + get(input) -> + outer_value = 8; + get(input) -> + {} + outer_value %= get(input) -> { + outer_value *= outer_value; + auto new_value = get(input); + outer_value += new_value; + } 
+ } + } + " + ); +} + +#[test] +fn test_incorrect_select_statement() { + Tester::new_single_source_expect_err( + "outside sync", + "func f() -> u32 { select {} return 0; }" + ).error(|e| { e + .assert_num(1) + .assert_occurs_at(0, "select") + .assert_msg_has(0, "inside sync blocks"); + }); } \ No newline at end of file