diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index 95195aef473b46c535fb2b034a712da10bf953ca..da5ba6db6c6a05dde8bafa165679eda350dd7256 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -137,6 +137,8 @@ define_new_ast_id!(SynchronousStatementId, StatementId, index(SynchronousStateme define_new_ast_id!(EndSynchronousStatementId, StatementId, index(EndSynchronousStatement, Statement::EndSynchronous, statements), alloc(alloc_end_synchronous_statement)); define_new_ast_id!(ForkStatementId, StatementId, index(ForkStatement, Statement::Fork, statements), alloc(alloc_fork_statement)); define_new_ast_id!(EndForkStatementId, StatementId, index(EndForkStatement, Statement::EndFork, statements), alloc(alloc_end_fork_statement)); +define_new_ast_id!(SelectStatementId, StatementId, index(SelectStatement, Statement::Select, statements), alloc(alloc_select_statement)); +define_new_ast_id!(EndSelectStatementId, StatementId, index(EndSelectStatement, Statement::EndSelect, statements), alloc(alloc_end_select_statement)); define_new_ast_id!(ReturnStatementId, StatementId, index(ReturnStatement, Statement::Return, statements), alloc(alloc_return_statement)); define_new_ast_id!(GotoStatementId, StatementId, index(GotoStatement, Statement::Goto, statements), alloc(alloc_goto_statement)); define_new_ast_id!(NewStatementId, StatementId, index(NewStatement, Statement::New, statements), alloc(alloc_new_statement)); @@ -697,21 +699,32 @@ impl<'a> Iterator for ConcreteTypeIter<'a> { pub enum Scope { Definition(DefinitionId), Regular(BlockStatementId), - Synchronous((SynchronousStatementId, BlockStatementId)), + Synchronous(SynchronousStatementId, BlockStatementId), } impl Scope { + pub(crate) fn new_invalid() -> Scope { + return Scope::Definition(DefinitionId::new_invalid()); + } + + pub(crate) fn is_invalid(&self) -> bool { + match self { + Scope::Definition(id) => id.is_invalid(), + _ => false, + } + } + pub fn is_block(&self) -> bool { match &self { Scope::Definition(_) => false, Scope::Regular(_) => true, - Scope::Synchronous(_) => true, + Scope::Synchronous(_, _) => true, } } pub fn to_block(&self) -> BlockStatementId { match &self { Scope::Regular(id) => *id, - Scope::Synchronous((_, id)) => *id, + Scope::Synchronous(_, id) => *id, _ => panic!("unable to get BlockStatement from Scope") } } @@ -724,13 +737,15 @@ impl Scope { pub struct ScopeNode { pub parent: Scope, pub nested: Vec, + pub relative_pos_in_parent: i32, } impl ScopeNode { pub(crate) fn new_invalid() -> Self { ScopeNode{ - parent: Scope::Definition(DefinitionId::new_invalid()), + parent: Scope::new_invalid(), nested: Vec::new(), + relative_pos_in_parent: -1, } } } @@ -750,7 +765,7 @@ pub struct Variable { pub parser_type: ParserType, pub identifier: Identifier, // Validator/linker - pub relative_pos_in_block: u32, + pub relative_pos_in_block: i32, pub unique_id_in_scope: i32, // Temporary fix until proper bytecode/asm is generated } @@ -1068,6 +1083,8 @@ pub enum Statement { EndSynchronous(EndSynchronousStatement), Fork(ForkStatement), EndFork(EndForkStatement), + Select(SelectStatement), + EndSelect(EndSelectStatement), Return(ReturnStatement), Goto(GotoStatement), New(NewStatement), @@ -1112,11 +1129,17 @@ impl Statement { Statement::Continue(v) => v.span, Statement::Synchronous(v) => v.span, Statement::Fork(v) => v.span, + Statement::Select(v) => v.span, Statement::Return(v) => v.span, Statement::Goto(v) => v.span, Statement::New(v) => v.span, Statement::Expression(v) => v.span, - Statement::EndBlock(_) | Statement::EndIf(_) | Statement::EndWhile(_) | Statement::EndSynchronous(_) | Statement::EndFork(_) => unreachable!(), + Statement::EndBlock(_) + | Statement::EndIf(_) + | Statement::EndWhile(_) + | Statement::EndSynchronous(_) + | Statement::EndFork(_) + | Statement::EndSelect(_) => unreachable!(), } } pub fn link_next(&mut self, next: StatementId) { @@ -1131,6 +1154,7 @@ impl Statement { Statement::EndWhile(stmt) => stmt.next = next, Statement::EndSynchronous(stmt) => stmt.next = next, Statement::EndFork(stmt) => stmt.next = next, + Statement::EndSelect(stmt) => stmt.next = next, Statement::New(stmt) => stmt.next = next, Statement::Expression(stmt) => stmt.next = next, Statement::Return(_) @@ -1138,6 +1162,7 @@ impl Statement { | Statement::Continue(_) | Statement::Synchronous(_) | Statement::Fork(_) + | Statement::Select(_) | Statement::Goto(_) | Statement::While(_) | Statement::Labeled(_) @@ -1158,7 +1183,6 @@ pub struct BlockStatement { pub scope_node: ScopeNode, pub first_unique_id_in_scope: i32, // Temporary fix until proper bytecode/asm is generated pub next_unique_id_in_scope: i32, // Temporary fix until proper bytecode/asm is generated - pub relative_pos_in_parent: u32, pub locals: Vec, pub labels: Vec, pub next: StatementId, @@ -1212,6 +1236,7 @@ pub struct MemoryStatement { // Phase 1: parser pub span: InputSpan, pub variable: VariableId, + pub initial_expr: AssignmentExpressionId, // Phase 2: linker pub next: StatementId, } @@ -1229,7 +1254,7 @@ pub struct ChannelStatement { pub from: VariableId, // output pub to: VariableId, // input // Phase 2: linker - pub relative_pos_in_block: u32, + pub relative_pos_in_block: i32, pub next: StatementId, } @@ -1240,7 +1265,7 @@ pub struct LabeledStatement { pub label: Identifier, pub body: StatementId, // Phase 2: linker - pub relative_pos_in_block: u32, + pub relative_pos_in_block: i32, pub in_sync: SynchronousStatementId, // may be invalid } @@ -1288,7 +1313,7 @@ pub struct BreakStatement { pub span: InputSpan, // of the "break" keyword pub label: Option, // Phase 2: linker - pub target: Option, + pub target: EndWhileStatementId, // invalid if not yet set } #[derive(Debug, Clone)] @@ -1298,7 +1323,7 @@ pub struct ContinueStatement { pub span: InputSpan, // of the "continue" keyword pub label: Option, // Phase 2: linker - pub target: Option, + pub target: WhileStatementId, // invalid if not yet set } #[derive(Debug, Clone)] @@ -1335,6 +1360,31 @@ pub struct EndForkStatement { pub next: StatementId, } +#[derive(Debug, Clone)] +pub struct SelectStatement { + pub this: SelectStatementId, + pub span: InputSpan, // of the "select" keyword + pub cases: Vec, + pub end_select: EndSelectStatementId, +} + +#[derive(Debug, Clone)] +pub struct SelectCase { + // The guard statement of a `select` is either a MemoryStatement or an + // ExpressionStatement. Nothing else is allowed by the initial parsing + pub guard: StatementId, + pub block: BlockStatementId, + // Phase 2: Validation and Linking + pub involved_ports: Vec<(CallExpressionId, ExpressionId)>, // call to `get` and its port argument +} + +#[derive(Debug, Clone)] +pub struct EndSelectStatement { + pub this: EndSelectStatementId, + pub start_select: SelectStatementId, + pub next: StatementId, +} + #[derive(Debug, Clone)] pub struct ReturnStatement { pub this: ReturnStatementId, @@ -1350,7 +1400,7 @@ pub struct GotoStatement { pub span: InputSpan, // of the "goto" keyword pub label: Identifier, // Phase 2: linker - pub target: Option, + pub target: LabeledStatementId, // invalid if not yet set } #[derive(Debug, Clone)] @@ -1376,6 +1426,7 @@ pub struct ExpressionStatement { #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum ExpressionParent { None, // only set during initial parsing + Memory(MemoryStatementId), If(IfStatementId), While(WhileStatementId), Return(ReturnStatementId), diff --git a/src/protocol/ast_printer.rs b/src/protocol/ast_printer.rs index 6bdc0dd1171d61344c781fe57b0ee3e296f35098..5529cd3dde91d3b9834ab3cc4ecce3d8641eef9d 100644 --- a/src/protocol/ast_printer.rs +++ b/src/protocol/ast_printer.rs @@ -38,6 +38,8 @@ const PREFIX_SYNC_STMT_ID: &'static str = "SSyn"; const PREFIX_ENDSYNC_STMT_ID: &'static str = "SESy"; const PREFIX_FORK_STMT_ID: &'static str = "SFrk"; const PREFIX_END_FORK_STMT_ID: &'static str = "SEFk"; +const PREFIX_SELECT_STMT_ID: &'static str = "SSel"; +const PREFIX_END_SELECT_STMT_ID: &'static str = "SESl"; const PREFIX_RETURN_STMT_ID: &'static str = "SRet"; const PREFIX_ASSERT_STMT_ID: &'static str = "SAsr"; const PREFIX_GOTO_STMT_ID: &'static str = "SGot"; @@ -401,7 +403,7 @@ impl ASTWriter { self.kv(indent2).with_s_key("EndBlockID").with_disp_val(&stmt.end_block.0.index); self.kv(indent2).with_s_key("FirstUniqueScopeID").with_disp_val(&stmt.first_unique_id_in_scope); self.kv(indent2).with_s_key("NextUniqueScopeID").with_disp_val(&stmt.next_unique_id_in_scope); - self.kv(indent2).with_s_key("RelativePos").with_disp_val(&stmt.relative_pos_in_parent); + self.kv(indent2).with_s_key("RelativePos").with_disp_val(&stmt.scope_node.relative_pos_in_parent); self.kv(indent2).with_s_key("Statements"); for stmt_id in &stmt.statements { @@ -431,6 +433,8 @@ impl ASTWriter { self.kv(indent2).with_s_key("Variable"); self.write_variable(heap, stmt.variable, indent3); + self.kv(indent2).with_s_key("InitialValue"); + self.write_expr(heap, stmt.initial_expr.upcast(), indent3); self.kv(indent2).with_s_key("Next").with_disp_val(&stmt.next.index); } } @@ -490,7 +494,7 @@ impl ASTWriter { self.kv(indent2).with_s_key("Label") .with_opt_identifier_val(stmt.label.as_ref()); self.kv(indent2).with_s_key("Target") - .with_opt_disp_val(stmt.target.as_ref().map(|v| &v.0.index)); + .with_disp_val(&stmt.target.0.index); }, Statement::Continue(stmt) => { self.kv(indent).with_id(PREFIX_CONTINUE_STMT_ID, stmt.this.0.index) @@ -498,7 +502,7 @@ impl ASTWriter { self.kv(indent2).with_s_key("Label") .with_opt_identifier_val(stmt.label.as_ref()); self.kv(indent2).with_s_key("Target") - .with_opt_disp_val(stmt.target.as_ref().map(|v| &v.0.index)); + .with_disp_val(&stmt.target.0.index); }, Statement::Synchronous(stmt) => { self.kv(indent).with_id(PREFIX_SYNC_STMT_ID, stmt.this.0.index) @@ -530,6 +534,27 @@ impl ASTWriter { .with_s_key("EndFork"); self.kv(indent2).with_s_key("StartFork").with_disp_val(&stmt.start_fork.0.index); self.kv(indent2).with_s_key("Next").with_disp_val(&stmt.next.index); + }, + Statement::Select(stmt) => { + self.kv(indent).with_id(PREFIX_SELECT_STMT_ID, stmt.this.0.index) + .with_s_key("Select"); + self.kv(indent2).with_s_key("EndSelect").with_disp_val(&stmt.end_select.0.index); + self.kv(indent2).with_s_key("Cases"); + let indent3 = indent2 + 1; + let indent4 = indent3 + 1; + for case in &stmt.cases { + self.kv(indent3).with_s_key("Guard"); + self.write_stmt(heap, case.guard, indent4); + + self.kv(indent3).with_s_key("Block"); + self.write_stmt(heap, case.block.upcast(), indent4); + } + }, + Statement::EndSelect(stmt) => { + self.kv(indent).with_id(PREFIX_END_SELECT_STMT_ID, stmt.this.0.index) + .with_s_key("EndSelect"); + self.kv(indent2).with_s_key("StartSelect").with_disp_val(&stmt.start_select.0.index); + self.kv(indent2).with_s_key("Next").with_disp_val(&stmt.next.index); } Statement::Return(stmt) => { self.kv(indent).with_id(PREFIX_RETURN_STMT_ID, stmt.this.0.index) @@ -544,7 +569,7 @@ impl ASTWriter { .with_s_key("Goto"); self.kv(indent2).with_s_key("Label").with_identifier_val(&stmt.label); self.kv(indent2).with_s_key("Target") - .with_opt_disp_val(stmt.target.as_ref().map(|v| &v.0.index)); + .with_disp_val(&stmt.target.0.index); }, Statement::New(stmt) => { self.kv(indent).with_id(PREFIX_NEW_STMT_ID, stmt.this.0.index) @@ -991,6 +1016,7 @@ fn write_expression_parent(target: &mut String, parent: &ExpressionParent) { *target = match parent { EP::None => String::from("None"), + EP::Memory(id) => format!("MemStmt({})", id.0.0.index), EP::If(id) => format!("IfStmt({})", id.0.index), EP::While(id) => format!("WhileStmt({})", id.0.index), EP::Return(id) => format!("ReturnStmt({})", id.0.index), diff --git a/src/protocol/eval/executor.rs b/src/protocol/eval/executor.rs index 130de897237dfa1854d68ef1225d452e527db3b2..feab9c135698e49f53dba5985a928d1644f48478 100644 --- a/src/protocol/eval/executor.rs +++ b/src/protocol/eval/executor.rs @@ -825,8 +825,13 @@ impl Prompt { Statement::Local(stmt) => { match stmt { LocalStatement::Memory(stmt) => { - let variable = &heap[stmt.variable]; - self.store.write(ValueId::Stack(variable.unique_id_in_scope as u32), Value::Unassigned); + if cfg!(debug_assertions) { + let variable = &heap[stmt.variable]; + debug_assert!(match self.store.read_ref(ValueId::Stack(variable.unique_id_in_scope as u32)) { + Value::Unassigned => false, + _ => true, + }); + } cur_frame.position = stmt.next; Ok(EvalContinuation::Stepping) @@ -891,12 +896,12 @@ impl Prompt { Ok(EvalContinuation::Stepping) }, Statement::Break(stmt) => { - cur_frame.position = stmt.target.unwrap().upcast(); + cur_frame.position = stmt.target.upcast(); Ok(EvalContinuation::Stepping) }, Statement::Continue(stmt) => { - cur_frame.position = stmt.target.unwrap().upcast(); + cur_frame.position = stmt.target.upcast(); Ok(EvalContinuation::Stepping) }, @@ -936,7 +941,15 @@ impl Prompt { cur_frame.position = stmt.next; Ok(EvalContinuation::Stepping) - } + }, + Statement::Select(_stmt) => { + todo!("implement select evaluation") + }, + Statement::EndSelect(stmt) => { + cur_frame.position = stmt.next; + + Ok(EvalContinuation::Stepping) + }, Statement::Return(_stmt) => { debug_assert!(heap[cur_frame.definition].is_function()); debug_assert_eq!(cur_frame.expr_values.len(), 1, "expected one expr value for return statement"); @@ -979,7 +992,7 @@ impl Prompt { return Ok(EvalContinuation::Stepping); }, Statement::Goto(stmt) => { - cur_frame.position = stmt.target.unwrap().upcast(); + cur_frame.position = stmt.target.upcast(); Ok(EvalContinuation::Stepping) }, @@ -1041,6 +1054,16 @@ impl Prompt { let stmt = &heap[cur_frame.position]; match stmt { + Statement::Local(stmt) => { + if let LocalStatement::Memory(stmt) = stmt { + // Setup as unassigned, when we execute the memory + // statement (after evaluating expression), it should no + // longer be `Unassigned`. + let variable = &heap[stmt.variable]; + self.store.write(ValueId::Stack(variable.unique_id_in_scope as u32), Value::Unassigned); + cur_frame.prepare_single_expression(heap, stmt.initial_expr.upcast()); + } + }, Statement::If(stmt) => cur_frame.prepare_single_expression(heap, stmt.test), Statement::While(stmt) => cur_frame.prepare_single_expression(heap, stmt.test), Statement::Return(stmt) => { diff --git a/src/protocol/parser/pass_definitions.rs b/src/protocol/parser/pass_definitions.rs index c6b2f994762f6f61637e05f7c3fd29e8198a074b..66ed638cc16da38f731d12225efd5d7c0a7da217 100644 --- a/src/protocol/parser/pass_definitions.rs +++ b/src/protocol/parser/pass_definitions.rs @@ -367,7 +367,6 @@ impl PassDefinitions { scope_node: ScopeNode::new_invalid(), first_unique_id_in_scope: -1, next_unique_id_in_scope: -1, - relative_pos_in_parent: 0, locals: Vec::new(), labels: Vec::new(), next: StatementId::new_invalid(), @@ -456,6 +455,19 @@ impl PassDefinitions { let fork_stmt = &mut ctx.heap[id]; fork_stmt.end_fork = end_fork; + } else if ident == KW_STMT_SELECT { + let id = self.consume_select_statement(module, iter, ctx)?; + section.push(id.upcast()); + + let end_select = ctx.heap.alloc_end_select_statement(|this| EndSelectStatement{ + this, + start_select: id, + next: StatementId::new_invalid(), + }); + section.push(end_select.upcast()); + + let select_stmt = &mut ctx.heap[id]; + select_stmt.end_select = end_select; } else if ident == KW_STMT_RETURN { let id = self.consume_return_statement(module, iter, ctx)?; section.push(id.upcast()); @@ -474,9 +486,9 @@ impl PassDefinitions { // Two fallback possibilities: the first one is a memory // declaration, the other one is to parse it as a normal // expression. This is a bit ugly. - if let Some((memory_stmt_id, assignment_stmt_id)) = self.maybe_consume_memory_statement(module, iter, ctx)? { + if let Some(memory_stmt_id) = self.maybe_consume_memory_statement_without_semicolon(module, iter, ctx)? { + consume_token(&module.source, iter, TokenKind::SemiColon)?; section.push(memory_stmt_id.upcast().upcast()); - section.push(assignment_stmt_id.upcast()); } else { let id = self.consume_expression_statement(module, iter, ctx)?; section.push(id.upcast()); @@ -484,9 +496,9 @@ impl PassDefinitions { } } else if next == TokenKind::OpenParen { // Same as above: memory statement or normal expression - if let Some((memory_stmt_id, assignment_stmt_id)) = self.maybe_consume_memory_statement(module, iter, ctx)? { + if let Some(memory_stmt_id) = self.maybe_consume_memory_statement_without_semicolon(module, iter, ctx)? { + consume_token(&module.source, iter, TokenKind::SemiColon)?; section.push(memory_stmt_id.upcast().upcast()); - section.push(assignment_stmt_id.upcast()); } else { let id = self.consume_expression_statement(module, iter, ctx)?; section.push(id.upcast()); @@ -534,7 +546,6 @@ impl PassDefinitions { scope_node: ScopeNode::new_invalid(), first_unique_id_in_scope: -1, next_unique_id_in_scope: -1, - relative_pos_in_parent: 0, locals: Vec::new(), labels: Vec::new(), next: StatementId::new_invalid(), @@ -611,7 +622,7 @@ impl PassDefinitions { this, span: break_span, label, - target: None, + target: EndWhileStatementId::new_invalid(), })) } @@ -630,7 +641,7 @@ impl PassDefinitions { this, span: continue_span, label, - target: None + target: WhileStatementId::new_invalid(), })) } @@ -671,6 +682,53 @@ impl PassDefinitions { })) } + fn consume_select_statement( + &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx + ) -> Result { + let select_span = consume_exact_ident(&module.source, iter, KW_STMT_SELECT)?; + consume_token(&module.source, iter, TokenKind::OpenCurly)?; + + let mut cases = Vec::new(); + let mut next = iter.next(); + + while Some(TokenKind::CloseCurly) != next { + let guard = match self.maybe_consume_memory_statement_without_semicolon(module, iter, ctx)? { + Some(guard_mem_stmt) => guard_mem_stmt.upcast().upcast(), + None => { + let start_pos = iter.last_valid_pos(); + let expr = self.consume_expression(module, iter, ctx)?; + let end_pos = iter.last_valid_pos(); + + let guard_expr_stmt = ctx.heap.alloc_expression_statement(|this| ExpressionStatement{ + this, + span: InputSpan::from_positions(start_pos, end_pos), + expression: expr, + next: StatementId::new_invalid(), + }); + + guard_expr_stmt.upcast() + }, + }; + consume_token(&module.source, iter, TokenKind::ArrowRight)?; + let block = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + cases.push(SelectCase{ + guard, block, + involved_ports: Vec::with_capacity(1) + }); + + next = iter.next(); + } + + consume_token(&module.source, iter, TokenKind::CloseCurly)?; + + Ok(ctx.heap.alloc_select_statement(|this| SelectStatement{ + this, + span: select_span, + cases, + end_select: EndSelectStatementId::new_invalid(), + })) + } + fn consume_return_statement( &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx ) -> Result { @@ -707,7 +765,7 @@ impl PassDefinitions { this, span: goto_span, label, - target: None + target: LabeledStatementId::new_invalid(), })) } @@ -866,9 +924,14 @@ impl PassDefinitions { Ok(()) } - fn maybe_consume_memory_statement( + /// Attempts to consume a memory statement (a statement along the lines of + /// `type var_name = initial_expr`). Will return `Ok(None)` if it didn't + /// seem like there was a memory statement, `Ok(Some(...))` if there was + /// one, and `Err(...)` if its reasonable to assume that there was a memory + /// statement, but we failed to parse it. + fn maybe_consume_memory_statement_without_semicolon( &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx - ) -> Result, ParseError> { + ) -> Result, ParseError> { // This is a bit ugly. It would be nicer if we could somehow // consume the expression with a type hint if we do get a valid // type, but we don't get an identifier following it @@ -888,12 +951,10 @@ impl PassDefinitions { let memory_span = InputSpan::from_positions(parser_type.full_span.begin, identifier.span.end); let assign_span = consume_token(&module.source, iter, TokenKind::Equal)?; - let initial_expr_begin_pos = iter.last_valid_pos(); let initial_expr_id = self.consume_expression(module, iter, ctx)?; let initial_expr_end_pos = iter.last_valid_pos(); - consume_token(&module.source, iter, TokenKind::SemiColon)?; - // Allocate the memory statement with the variable + // Create the AST variable let local_id = ctx.heap.alloc_variable(|this| Variable{ this, kind: VariableKind::Local, @@ -902,18 +963,13 @@ impl PassDefinitions { relative_pos_in_block: 0, unique_id_in_scope: -1, }); - let memory_stmt_id = ctx.heap.alloc_memory_statement(|this| MemoryStatement{ - this, - span: memory_span, - variable: local_id, - next: StatementId::new_invalid() - }); - // Allocate the initial assignment + // Create the initial assignment expression + // Note: we set the initial variable declaration here let variable_expr_id = ctx.heap.alloc_variable_expression(|this| VariableExpression{ this, identifier, - declaration: None, + declaration: Some(local_id), used_as_binding_target: false, parent: ExpressionParent::None, unique_id_in_definition: -1, @@ -928,14 +984,17 @@ impl PassDefinitions { parent: ExpressionParent::None, unique_id_in_definition: -1, }); - let assignment_stmt_id = ctx.heap.alloc_expression_statement(|this| ExpressionStatement{ + + // Put both together in the memory statement + let memory_stmt_id = ctx.heap.alloc_memory_statement(|this| MemoryStatement{ this, - span: InputSpan::from_positions(initial_expr_begin_pos, initial_expr_end_pos), - expression: assignment_expr_id.upcast(), - next: StatementId::new_invalid(), + span: memory_span, + variable: local_id, + initial_expr: assignment_expr_id, + next: StatementId::new_invalid() }); - return Ok(Some((memory_stmt_id, assignment_stmt_id))) + return Ok(Some(memory_stmt_id)); } } diff --git a/src/protocol/parser/pass_typing.rs b/src/protocol/parser/pass_typing.rs index 4eb0f85958973c5846df37d892465fb89e1b5321..e31e13c896efe7558d80c95de258c845868faeb9 100644 --- a/src/protocol/parser/pass_typing.rs +++ b/src/protocol/parser/pass_typing.rs @@ -1094,11 +1094,16 @@ impl Visitor for PassTyping { fn visit_local_memory_stmt(&mut self, ctx: &mut Ctx, id: MemoryStatementId) -> VisitorResult { let memory_stmt = &ctx.heap[id]; + let initial_expr_id = memory_stmt.initial_expr; + // Setup memory statement inference let local = &ctx.heap[memory_stmt.variable]; let var_type = self.determine_inference_type_from_parser_type_elements(&local.parser_type.elements, true); self.var_types.insert(memory_stmt.variable, VarData::new_local(var_type)); + // Process the initial value + self.visit_assignment_expr(ctx, initial_expr_id)?; + Ok(()) } @@ -1170,6 +1175,30 @@ impl Visitor for PassTyping { Ok(()) } + fn visit_select_stmt(&mut self, ctx: &mut Ctx, id: SelectStatementId) -> VisitorResult { + let select_stmt = &ctx.heap[id]; + + let mut section = self.stmt_buffer.start_section(); + let num_cases = select_stmt.cases.len(); + + for case in &select_stmt.cases { + section.push(case.guard); + section.push(case.block.upcast()); + } + + for case_index in 0..num_cases { + let base_index = 2 * case_index; + let guard_stmt_id = section[base_index ]; + let block_stmt_id = section[base_index + 1]; + + self.visit_stmt(ctx, guard_stmt_id)?; + self.visit_stmt(ctx, block_stmt_id)?; + } + section.forget(); + + Ok(()) + } + fn visit_return_stmt(&mut self, ctx: &mut Ctx, id: ReturnStatementId) -> VisitorResult { let return_stmt = &ctx.heap[id]; debug_assert_eq!(return_stmt.expressions.len(), 1); @@ -3276,7 +3305,7 @@ impl PassTyping { EP::None => // Should have been set by linker unreachable!(), - EP::ExpressionStmt(_) => + EP::Memory(_) | EP::ExpressionStmt(_) => // Determined during type inference InferenceType::new(false, false, vec![ITP::Unknown]), EP::Expression(parent_id, idx_in_parent) => { diff --git a/src/protocol/parser/pass_validation_linking.rs b/src/protocol/parser/pass_validation_linking.rs index df9d4551039e33aced60cb296859e92fd23c9d79..dd755059ac902761a6752d00a6b036909d572d6a 100644 --- a/src/protocol/parser/pass_validation_linking.rs +++ b/src/protocol/parser/pass_validation_linking.rs @@ -69,6 +69,13 @@ impl DefinitionType { } } +struct ControlFlowStatement { + in_sync: SynchronousStatementId, + in_while: WhileStatementId, + in_scope: Scope, + statement: StatementId, // of 'break', 'continue' or 'goto' +} + /// This particular visitor will go through the entire AST in a recursive manner /// and check if all statements and expressions are legal (e.g. no "return" /// statements in component definitions), and will link certain AST nodes to @@ -79,11 +86,16 @@ impl DefinitionType { /// the linking of function calls and component instantiations will be checked /// and linked to the appropriate definitions, but the return types and/or /// arguments will not be checked for validity. +/// +/// The main idea is, because we're visiting nodes in a tree, to do as much as +/// we can while we have the memory in cache. pub(crate) struct PassValidationLinking { // Traversal state, all valid IDs if inside a certain AST element. Otherwise // `id.is_invalid()` returns true. in_sync: SynchronousStatementId, in_while: WhileStatementId, // to resolve labeled continue/break + in_select_guard: SelectStatementId, // for detection/rejection of builtin calls + in_select_arm: u32, in_test_expr: StatementId, // wrapping if/while stmt id in_binding_expr: BindingExpressionId, // to resolve variable expressions in_binding_expr_lhs: bool, @@ -99,8 +111,10 @@ pub(crate) struct PassValidationLinking { // used for the error's position must_be_assignable: Option, // Keeping track of relative positions and unique IDs. - relative_pos_in_block: u32, // of statements: to determine when variables are visible + relative_pos_in_block: i32, // of statements: to determine when variables are visible next_expr_index: i32, // to arrive at a unique ID for all expressions within a definition + // Control flow statements that require label resolving + control_flow_stmts: Vec, // Various temporary buffers for traversal. Essentially working around // Rust's borrowing rules since it cannot understand we're modifying AST // members but not the AST container. @@ -115,16 +129,19 @@ impl PassValidationLinking { Self{ in_sync: SynchronousStatementId::new_invalid(), in_while: WhileStatementId::new_invalid(), + in_select_guard: SelectStatementId::new_invalid(), + in_select_arm: 0, in_test_expr: StatementId::new_invalid(), in_binding_expr: BindingExpressionId::new_invalid(), in_binding_expr_lhs: false, - cur_scope: Scope::Definition(DefinitionId::new_invalid()), + cur_scope: Scope::new_invalid(), prev_stmt: StatementId::new_invalid(), expr_parent: ExpressionParent::None, def_type: DefinitionType::Function(FunctionDefinitionId::new_invalid()), must_be_assignable: None, relative_pos_in_block: 0, next_expr_index: 0, + control_flow_stmts: Vec::with_capacity(32), variable_buffer: ScopedBuffer::with_capacity(128), definition_buffer: ScopedBuffer::with_capacity(128), statement_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAPACITY), @@ -135,16 +152,18 @@ impl PassValidationLinking { fn reset_state(&mut self) { self.in_sync = SynchronousStatementId::new_invalid(); self.in_while = WhileStatementId::new_invalid(); + self.in_select_guard = SelectStatementId::new_invalid(); self.in_test_expr = StatementId::new_invalid(); self.in_binding_expr = BindingExpressionId::new_invalid(); self.in_binding_expr_lhs = false; - self.cur_scope = Scope::Definition(DefinitionId::new_invalid()); + self.cur_scope = Scope::new_invalid(); self.def_type = DefinitionType::Function(FunctionDefinitionId::new_invalid()); self.prev_stmt = StatementId::new_invalid(); self.expr_parent = ExpressionParent::None; self.must_be_assignable = None; self.relative_pos_in_block = 0; - self.next_expr_index = 0 + self.next_expr_index = 0; + self.control_flow_stmts.clear(); } } @@ -212,6 +231,7 @@ impl Visitor for PassValidationLinking { // to each of the locals in the procedure. ctx.heap[id].num_expressions_in_body = self.next_expr_index; self.visit_definition_and_assign_local_ids(ctx, id.upcast()); + self.resolve_pending_control_flow_targets(ctx)?; Ok(()) } @@ -242,6 +262,7 @@ impl Visitor for PassValidationLinking { // to each of the locals in the procedure. ctx.heap[id].num_expressions_in_body = self.next_expr_index; self.visit_definition_and_assign_local_ids(ctx, id.upcast()); + self.resolve_pending_control_flow_targets(ctx)?; Ok(()) } @@ -251,23 +272,65 @@ impl Visitor for PassValidationLinking { //-------------------------------------------------------------------------- fn visit_block_stmt(&mut self, ctx: &mut Ctx, id: BlockStatementId) -> VisitorResult { - self.visit_block_stmt_with_hint(ctx, id, None) + let old_scope = self.push_statement_scope(ctx, Scope::Regular(id)); + + // Set end of block + let block_stmt = &ctx.heap[id]; + let end_block_id = block_stmt.end_block; + + // Copy statement IDs into buffer + + // Traverse statements in block + let statement_section = self.statement_buffer.start_section_initialized(&block_stmt.statements); + assign_and_replace_next_stmt!(self, ctx, id.upcast()); + + for stmt_idx in 0..statement_section.len() { + self.relative_pos_in_block = stmt_idx as i32; + self.visit_stmt(ctx, statement_section[stmt_idx])?; + } + + statement_section.forget(); + assign_and_replace_next_stmt!(self, ctx, end_block_id.upcast()); + + self.pop_statement_scope(old_scope); + Ok(()) } fn visit_local_memory_stmt(&mut self, ctx: &mut Ctx, id: MemoryStatementId) -> VisitorResult { + let stmt = &ctx.heap[id]; + let expr_id = stmt.initial_expr; + let variable_id = stmt.variable; + + self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_block, variable_id)?; + assign_and_replace_next_stmt!(self, ctx, id.upcast().upcast()); + debug_assert_eq!(self.expr_parent, ExpressionParent::None); + self.expr_parent = ExpressionParent::Memory(id); + self.visit_assignment_expr(ctx, expr_id)?; + self.expr_parent = ExpressionParent::None; + Ok(()) } fn visit_local_channel_stmt(&mut self, ctx: &mut Ctx, id: ChannelStatementId) -> VisitorResult { + let stmt = &ctx.heap[id]; + let from_id = stmt.from; + let to_id = stmt.to; + + self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_block, from_id)?; + self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_block, to_id)?; + assign_and_replace_next_stmt!(self, ctx, id.upcast().upcast()); Ok(()) } fn visit_labeled_stmt(&mut self, ctx: &mut Ctx, id: LabeledStatementId) -> VisitorResult { - let body_id = ctx.heap[id].body; - self.visit_stmt(ctx, body_id)?; + let stmt = &ctx.heap[id]; + let body_id = stmt.body; + + self.checked_add_label(ctx, self.relative_pos_in_block, self.in_sync, id)?; + self.visit_stmt(ctx, body_id)?; Ok(()) } @@ -339,33 +402,25 @@ impl Visitor for PassValidationLinking { } fn visit_break_stmt(&mut self, ctx: &mut Ctx, id: BreakStatementId) -> VisitorResult { - // Resolve break target - let target_end_while = { - let stmt = &ctx.heap[id]; - let target_while_id = self.resolve_break_or_continue_target(ctx, stmt.span, &stmt.label)?; - let target_while = &ctx.heap[target_while_id]; - debug_assert!(!target_while.end_while.is_invalid()); - - target_while.end_while - }; - + self.control_flow_stmts.push(ControlFlowStatement{ + in_sync: self.in_sync, + in_while: self.in_while, + in_scope: self.cur_scope, + statement: id.upcast() + }); assign_then_erase_next_stmt!(self, ctx, id.upcast()); - let stmt = &mut ctx.heap[id]; - stmt.target = Some(target_end_while); Ok(()) } fn visit_continue_stmt(&mut self, ctx: &mut Ctx, id: ContinueStatementId) -> VisitorResult { - // Resolve continue target - let target_while_id = { - let stmt = &ctx.heap[id]; - self.resolve_break_or_continue_target(ctx, stmt.span, &stmt.label)? - }; - + self.control_flow_stmts.push(ControlFlowStatement{ + in_sync: self.in_sync, + in_while: self.in_while, + in_scope: self.cur_scope, + statement: id.upcast() + }); assign_then_erase_next_stmt!(self, ctx, id.upcast()); - let stmt = &mut ctx.heap[id]; - stmt.target = Some(target_while_id); Ok(()) } @@ -395,10 +450,15 @@ impl Visitor for PassValidationLinking { // Synchronous statement implicitly moves to its block assign_then_erase_next_stmt!(self, ctx, id.upcast()); + // Visit block statement. Note that we explicitly push the scope here + // (and the `visit_block_stmt` will also push, but without effect) to + // ensure the scope contains the sync ID. let sync_body = ctx.heap[id].body; debug_assert!(self.in_sync.is_invalid()); self.in_sync = id; - self.visit_block_stmt_with_hint(ctx, sync_body, Some(id))?; + let old_scope = self.push_statement_scope(ctx, Scope::Synchronous(id, sync_body)); + self.visit_block_stmt(ctx, sync_body)?; + self.pop_statement_scope(old_scope); assign_and_replace_next_stmt!(self, ctx, end_sync_id.upcast()); self.in_sync = SynchronousStatementId::new_invalid(); @@ -436,6 +496,69 @@ impl Visitor for PassValidationLinking { Ok(()) } + fn visit_select_stmt(&mut self, ctx: &mut Ctx, id: SelectStatementId) -> VisitorResult { + let select_stmt = &ctx.heap[id]; + let end_select_id = select_stmt.end_select; + + // Select statements may only occur inside sync blocks + if self.in_sync.is_invalid() { + return Err(ParseError::new_error_str_at_span( + &ctx.module().source, select_stmt.span, + "select statements may only occur inside sync blocks" + )); + } + + if !self.def_type.is_primitive() { + return Err(ParseError::new_error_str_at_span( + &ctx.module().source, select_stmt.span, + "select statements may only be used in primitive components" + )); + } + + // Visit the various arms in the select block + let mut case_stmt_ids = self.statement_buffer.start_section(); + let num_cases = select_stmt.cases.len(); + for case in &select_stmt.cases { + // Note: we add both to the buffer, retrieve them later in indexed + // fashion + case_stmt_ids.push(case.guard); + case_stmt_ids.push(case.block.upcast()); + } + + assign_then_erase_next_stmt!(self, ctx, id.upcast()); + + for idx in 0..num_cases { + let base_idx = 2 * idx; + let guard_id = case_stmt_ids[base_idx ]; + let arm_block_id = case_stmt_ids[base_idx + 1]; + debug_assert_eq!(ctx.heap[arm_block_id].as_block().this.upcast(), arm_block_id); // backwards way of saying arm_block_id is a BlockStatementId + let arm_block_id = BlockStatementId(arm_block_id); + + // The guard statement ends up belonging to the block statement + // following the arm. The reason we parse it separately is to + // extract all of the "get" calls. + let old_scope = self.push_statement_scope(ctx, Scope::Regular(arm_block_id)); + + // Visit the guard of this arm + debug_assert!(self.in_select_guard.is_invalid()); + self.in_select_guard = id; + self.in_select_arm = idx as u32; + self.visit_stmt(ctx, guard_id)?; + self.in_select_guard = SelectStatementId::new_invalid(); + + // Visit the code associated with the guard + self.visit_block_stmt(ctx, arm_block_id)?; + self.pop_statement_scope(old_scope); + + // Link up last statement in block to EndSelect + assign_then_erase_next_stmt!(self, ctx, end_select_id.upcast()); + } + + self.in_select_guard = SelectStatementId::new_invalid(); + self.prev_stmt = end_select_id.upcast(); + Ok(()) + } + fn visit_return_stmt(&mut self, ctx: &mut Ctx, id: ReturnStatementId) -> VisitorResult { // Check if "return" occurs within a function let stmt = &ctx.heap[id]; @@ -458,24 +581,12 @@ impl Visitor for PassValidationLinking { } fn visit_goto_stmt(&mut self, ctx: &mut Ctx, id: GotoStatementId) -> VisitorResult { - let target_id = self.find_label(ctx, &ctx.heap[id].label)?; - ctx.heap[id].target = Some(target_id); - - let target = &ctx.heap[target_id]; - if self.in_sync != target.in_sync { - // We can only goto the current scope or outer scopes. Because - // nested sync statements are not allowed we must be inside a sync - // statement. - debug_assert!(!self.in_sync.is_invalid()); - let goto_stmt = &ctx.heap[id]; - let sync_stmt = &ctx.heap[self.in_sync]; - return Err( - ParseError::new_error_str_at_span(&ctx.module().source, goto_stmt.span, "goto may not escape the surrounding synchronous block") - .with_info_str_at_span(&ctx.module().source, target.label.span, "this is the target of the goto statement") - .with_info_str_at_span(&ctx.module().source, sync_stmt.span, "which will jump past this statement") - ); - } - + self.control_flow_stmts.push(ControlFlowStatement{ + in_sync: self.in_sync, + in_while: self.in_while, + in_scope: self.cur_scope, + statement: id.upcast(), + }); assign_then_erase_next_stmt!(self, ctx, id.upcast()); Ok(()) @@ -530,6 +641,7 @@ impl Visitor for PassValidationLinking { // code (mainly typechecking), we disallow nested use in expressions match self.expr_parent { // Look at us: lying through our teeth while providing error messages. + ExpressionParent::Memory(_) => {}, ExpressionParent::ExpressionStmt(_) => {}, _ => { let assignment_span = assignment_expr.full_span; @@ -1044,7 +1156,7 @@ impl Visitor for PassValidationLinking { } fn visit_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> VisitorResult { - let call_expr = &mut ctx.heap[id]; + let call_expr = &ctx.heap[id]; if let Some(span) = self.must_be_assignable { return Err(ParseError::new_error_str_at_span( @@ -1054,59 +1166,43 @@ impl Visitor for PassValidationLinking { // Check whether the method is allowed to be called within the code's // context (in sync, definition type, etc.) - let mut expected_wrapping_new_stmt = false; - match &mut call_expr.method { + let mut expecting_wrapping_new_stmt = false; + let mut expecting_primitive_def = false; + let mut expecting_wrapping_sync_stmt = false; + let mut expecting_no_select_stmt = false; + + match call_expr.method { Method::Get => { - if !self.def_type.is_primitive() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "a call to 'get' may only occur in primitive component definitions" - )); - } - if self.in_sync.is_invalid() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "a call to 'get' may only occur inside synchronous blocks" - )); + expecting_primitive_def = true; + expecting_wrapping_sync_stmt = true; + if !self.in_select_guard.is_invalid() { + // In a select guard. Take the argument (i.e. the port we're + // retrieving from) and add it to the list of involved ports + // of the guard + if call_expr.arguments.len() == 1 { + // We're checking the number of arguments later, for now + // assume it is correct. + let argument = call_expr.arguments[0]; + let select_stmt = &mut ctx.heap[self.in_select_guard]; + let select_case = &mut select_stmt.cases[self.in_select_arm as usize]; + select_case.involved_ports.push((id, argument)); + } } }, Method::Put => { - if !self.def_type.is_primitive() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "a call to 'put' may only occur in primitive component definitions" - )); - } - if self.in_sync.is_invalid() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "a call to 'put' may only occur inside synchronous blocks" - )); - } + expecting_primitive_def = true; + expecting_wrapping_sync_stmt = true; + expecting_no_select_stmt = true; }, Method::Fires => { - if !self.def_type.is_primitive() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "a call to 'fires' may only occur in primitive component definitions" - )); - } - if self.in_sync.is_invalid() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "a call to 'fires' may only occur inside synchronous blocks" - )); - } + expecting_primitive_def = true; + expecting_wrapping_sync_stmt = true; }, Method::Create => {}, Method::Length => {}, Method::Assert => { + expecting_wrapping_sync_stmt = true; + expecting_no_select_stmt = true; if self.def_type.is_function() { let call_span = call_expr.func_span; return Err(ParseError::new_error_str_at_span( @@ -1114,22 +1210,53 @@ impl Visitor for PassValidationLinking { "assert statement may only occur in components" )); } - if self.in_sync.is_invalid() { - let call_span = call_expr.func_span; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, call_span, - "assert statements may only occur inside synchronous blocks" - )); - } }, Method::Print => {}, Method::UserFunction => {}, Method::UserComponent => { - expected_wrapping_new_stmt = true; + expecting_wrapping_new_stmt = true; }, } - if expected_wrapping_new_stmt { + let call_expr = &mut ctx.heap[id]; + + fn get_span_and_name<'a>(ctx: &'a Ctx, id: CallExpressionId) -> (InputSpan, String) { + let call = &ctx.heap[id]; + let span = call.func_span; + let name = String::from_utf8_lossy(ctx.module().source.section_at_span(span)).to_string(); + return (span, name); + } + if expecting_primitive_def { + if !self.def_type.is_primitive() { + let (call_span, func_name) = get_span_and_name(ctx, id); + return Err(ParseError::new_error_at_span( + &ctx.module().source, call_span, + format!("a call to '{}' may only occur in primitive component definitions", func_name) + )); + } + } + + if expecting_wrapping_sync_stmt { + if self.in_sync.is_invalid() { + let (call_span, func_name) = get_span_and_name(ctx, id); + return Err(ParseError::new_error_at_span( + &ctx.module().source, call_span, + format!("a call to '{}' may only occur inside synchronous blocks", func_name) + )) + } + } + + if expecting_no_select_stmt { + if !self.in_select_guard.is_invalid() { + let (call_span, func_name) = get_span_and_name(ctx, id); + return Err(ParseError::new_error_at_span( + &ctx.module().source, call_span, + format!("a call to '{}' may not occur in a select statement's guard", func_name) + )); + } + } + + if expecting_wrapping_new_stmt { if !self.expr_parent.is_new() { let call_span = call_expr.func_span; return Err(ParseError::new_error_str_at_span( @@ -1191,96 +1318,103 @@ impl Visitor for PassValidationLinking { fn visit_variable_expr(&mut self, ctx: &mut Ctx, id: VariableExpressionId) -> VisitorResult { let var_expr = &ctx.heap[id]; - let (variable_id, is_binding_target) = match self.find_variable(ctx, self.relative_pos_in_block, &var_expr.identifier) { - Ok(variable_id) => { - // Regular variable - (variable_id, false) - }, - Err(()) => { - // Couldn't find variable, but if we're in a binding expression, - // then this may be the thing we're binding to. - if self.in_binding_expr.is_invalid() || !self.in_binding_expr_lhs { - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, var_expr.identifier.span, "unresolved variable" - )); - } + // Check if declaration was already resolved (this occurs for the + // variable expr that is on the LHS of the assignment expr that is + // associated with a variable declaration) + let mut variable_id = var_expr.declaration; + let mut is_binding_target = false; - // This is a binding variable, but it may only appear in very - // specific locations. - let is_valid_binding = match self.expr_parent { - ExpressionParent::Expression(expr_id, idx) => { - match &ctx.heap[expr_id] { - Expression::Binding(_binding_expr) => { - // Nested binding is disallowed, and because of - // the check above we know we're directly at the - // LHS of the binding expression - debug_assert_eq!(_binding_expr.this, self.in_binding_expr); - debug_assert_eq!(idx, 0); - true - } - Expression::Literal(lit_expr) => { - // Only struct, unions and arrays can have - // subexpressions, so we're always fine - if cfg!(debug_assertions) { - match lit_expr.value { - Literal::Struct(_) | Literal::Union(_) | Literal::Array(_) | Literal::Tuple(_) => {}, - _ => unreachable!(), - } - } + // Otherwise try to find it + if variable_id.is_none() { + variable_id = self.find_variable(ctx, self.relative_pos_in_block, &var_expr.identifier); + } - true - }, - _ => false, + // Otherwise try to see if is a variable introduced by a binding expr + let variable_id = if let Some(variable_id) = variable_id { + variable_id + } else { + if self.in_binding_expr.is_invalid() || !self.in_binding_expr_lhs { + return Err(ParseError::new_error_str_at_span( + &ctx.module().source, var_expr.identifier.span, "unresolved variable" + )); + } + + // This is a binding variable, but it may only appear in very + // specific locations. + let is_valid_binding = match self.expr_parent { + ExpressionParent::Expression(expr_id, idx) => { + match &ctx.heap[expr_id] { + Expression::Binding(_binding_expr) => { + // Nested binding is disallowed, and because of + // the check above we know we're directly at the + // LHS of the binding expression + debug_assert_eq!(_binding_expr.this, self.in_binding_expr); + debug_assert_eq!(idx, 0); + true } - }, - _ => { - false - } - }; + Expression::Literal(lit_expr) => { + // Only struct, unions, tuples and arrays can + // have subexpressions, so we're always fine + if cfg!(debug_assertions) { + match lit_expr.value { + Literal::Struct(_) | Literal::Union(_) | Literal::Array(_) | Literal::Tuple(_) => {}, + _ => unreachable!(), + } + } - if !is_valid_binding { - let binding_expr = &ctx.heap[self.in_binding_expr]; - return Err(ParseError::new_error_str_at_span( - &ctx.module().source, var_expr.identifier.span, - "illegal location for binding variable: binding variables may only be nested under a binding expression, or a struct, union or array literal" - ).with_info_at_span( - &ctx.module().source, binding_expr.operator_span, format!( - "'{}' was interpreted as a binding variable because the variable is not declared and it is nested under this binding expression", - var_expr.identifier.value.as_str() - ) - )); + true + }, + _ => false, + } + }, + _ => { + false } + }; - // By now we know that this is a valid binding expression. Given - // that a binding expression must be nested under an if/while - // statement, we now add the variable to the (implicit) block - // statement following the if/while statement. - let bound_identifier = var_expr.identifier.clone(); - let bound_variable_id = ctx.heap.alloc_variable(|this| Variable{ - this, - kind: VariableKind::Binding, - parser_type: ParserType{ - elements: vec![ParserTypeElement{ - element_span: bound_identifier.span, - variant: ParserTypeVariant::Inferred - }], - full_span: bound_identifier.span - }, - identifier: bound_identifier, - relative_pos_in_block: 0, - unique_id_in_scope: -1, - }); + if !is_valid_binding { + let binding_expr = &ctx.heap[self.in_binding_expr]; + return Err(ParseError::new_error_str_at_span( + &ctx.module().source, var_expr.identifier.span, + "illegal location for binding variable: binding variables may only be nested under a binding expression, or a struct, union or array literal" + ).with_info_at_span( + &ctx.module().source, binding_expr.operator_span, format!( + "'{}' was interpreted as a binding variable because the variable is not declared and it is nested under this binding expression", + var_expr.identifier.value.as_str() + ) + )); + } - let body_stmt_id = match &ctx.heap[self.in_test_expr] { - Statement::If(stmt) => stmt.true_body, - Statement::While(stmt) => stmt.body, - _ => unreachable!(), - }; - let body_scope = Scope::Regular(body_stmt_id); - self.checked_at_single_scope_add_local(ctx, body_scope, 0, bound_variable_id)?; + // By now we know that this is a valid binding expression. Given + // that a binding expression must be nested under an if/while + // statement, we now add the variable to the (implicit) block + // statement following the if/while statement. + let bound_identifier = var_expr.identifier.clone(); + let bound_variable_id = ctx.heap.alloc_variable(|this| Variable { + this, + kind: VariableKind::Binding, + parser_type: ParserType { + elements: vec![ParserTypeElement { + element_span: bound_identifier.span, + variant: ParserTypeVariant::Inferred + }], + full_span: bound_identifier.span + }, + identifier: bound_identifier, + relative_pos_in_block: 0, + unique_id_in_scope: -1, + }); + + let body_stmt_id = match &ctx.heap[self.in_test_expr] { + Statement::If(stmt) => stmt.true_body, + Statement::While(stmt) => stmt.body, + _ => unreachable!(), + }; + let body_scope = Scope::Regular(body_stmt_id); + self.checked_at_single_scope_add_local(ctx, body_scope, -1, bound_variable_id)?; // add at -1 such that first statement can access - (bound_variable_id, true) - } + is_binding_target = true; + bound_variable_id }; let var_expr = &mut ctx.heap[id]; @@ -1299,99 +1433,48 @@ impl PassValidationLinking { // Special traversal //-------------------------------------------------------------------------- - fn visit_block_stmt_with_hint(&mut self, ctx: &mut Ctx, id: BlockStatementId, hint: Option) -> VisitorResult { - // Set parent scope and relative position in the parent scope. Remember - // these values to set them back to the old values when we're done with - // the traversal of the block's statements. + /// Pushes a new scope associated with a particular statement. If that + /// statement already has an associated scope (i.e. scope associated with + /// sync statement or select statement's arm) then we won't do anything. + /// In all cases the caller must call `pop_statement_scope` with the scope + /// and relative scope position returned by this function. + fn push_statement_scope(&mut self, ctx: &mut Ctx, new_scope: Scope) -> (Scope, i32) { let old_scope = self.cur_scope.clone(); - let new_scope = match hint { - Some(sync_id) => Scope::Synchronous((sync_id, id)), - None => Scope::Regular(id), + debug_assert!(new_scope.is_block()); // never call for Definition scope + let is_new_block = if old_scope.is_block() { + old_scope.to_block() != new_scope.to_block() + } else { + true }; - match old_scope { - Scope::Definition(_def_id) => { - // Don't do anything. Block is implicitly a child of a - // definition scope. - if cfg!(debug_assertions) { - match &ctx.heap[_def_id] { - Definition::Function(proc_def) => debug_assert_eq!(proc_def.body, id), - Definition::Component(proc_def) => debug_assert_eq!(proc_def.body, id), - _ => unreachable!(), - } - } - }, - Scope::Regular(block_id) | Scope::Synchronous((_, block_id)) => { - let parent_block = &mut ctx.heap[block_id]; - parent_block.scope_node.nested.push(new_scope); - } + if !is_new_block { + // No need to push, but still return old scope, we pretend like we + // replaced it. + debug_assert!(!ctx.heap[new_scope.to_block()].scope_node.parent.is_invalid()); + return (old_scope, self.relative_pos_in_block); + } + + // This is a new block, so link it up + if old_scope.is_block() { + let parent_block = &mut ctx.heap[old_scope.to_block()]; + parent_block.scope_node.nested.push(new_scope); } self.cur_scope = new_scope; - let body = &mut ctx.heap[id]; - body.scope_node.parent = old_scope; - body.relative_pos_in_parent = self.relative_pos_in_block; - let end_block_id = body.end_block; + let cur_block = &mut ctx.heap[new_scope.to_block()]; + cur_block.scope_node.parent = old_scope; + cur_block.scope_node.relative_pos_in_parent = self.relative_pos_in_block; let old_relative_pos = self.relative_pos_in_block; + self.relative_pos_in_block = -1; - // Copy statement IDs into buffer - let statement_section = self.statement_buffer.start_section_initialized(&body.statements); - - // Perform the breadth-first pass. Its main purpose is to find labeled - // statements such that we can find the `goto`-targets immediately when - // performing the depth pass - for stmt_idx in 0..statement_section.len() { - self.relative_pos_in_block = stmt_idx as u32; - self.visit_statement_for_locals_labels_and_in_sync(ctx, self.relative_pos_in_block, statement_section[stmt_idx])?; - } - - // Perform the depth-first traversal - assign_and_replace_next_stmt!(self, ctx, id.upcast()); - for stmt_idx in 0..statement_section.len() { - self.relative_pos_in_block = stmt_idx as u32; - self.visit_stmt(ctx, statement_section[stmt_idx])?; - } - assign_and_replace_next_stmt!(self, ctx, end_block_id.upcast()); - - self.cur_scope = old_scope; - self.relative_pos_in_block = old_relative_pos; - statement_section.forget(); - - Ok(()) + return (old_scope, old_relative_pos) } - fn visit_statement_for_locals_labels_and_in_sync(&mut self, ctx: &mut Ctx, relative_pos: u32, id: StatementId) -> VisitorResult { - let statement = &mut ctx.heap[id]; - match statement { - Statement::Local(stmt) => { - match stmt { - LocalStatement::Memory(local) => { - let variable_id = local.variable; - self.checked_add_local(ctx, relative_pos, variable_id)?; - }, - LocalStatement::Channel(local) => { - let from_id = local.from; - let to_id = local.to; - self.checked_add_local(ctx, relative_pos, from_id)?; - self.checked_add_local(ctx, relative_pos, to_id)?; - } - } - } - Statement::Labeled(stmt) => { - let stmt_id = stmt.this; - let body_id = stmt.body; - self.checked_add_label(ctx, relative_pos, self.in_sync, stmt_id)?; - self.visit_statement_for_locals_labels_and_in_sync(ctx, relative_pos, body_id)?; - }, - Statement::While(stmt) => { - stmt.in_sync = self.in_sync; - }, - _ => {}, - } - - return Ok(()) + fn pop_statement_scope(&mut self, scope_to_restore: (Scope, i32)) { + self.cur_scope = scope_to_restore.0; + self.relative_pos_in_block = scope_to_restore.1; } fn visit_definition_and_assign_local_ids(&mut self, ctx: &mut Ctx, definition_id: DefinitionId) { @@ -1439,16 +1522,16 @@ impl PassValidationLinking { let relative_var_pos = if var_idx < var_section.len() { ctx.heap[var_section[var_idx]].relative_pos_in_block } else { - u32::MAX + i32::MAX }; let relative_scope_pos = if scope_idx < scope_section.len() { - ctx.heap[scope_section[scope_idx]].as_block().relative_pos_in_parent + ctx.heap[scope_section[scope_idx]].as_block().scope_node.relative_pos_in_parent } else { - u32::MAX + i32::MAX }; - debug_assert!(!(relative_var_pos == u32::MAX && relative_scope_pos == u32::MAX)); + debug_assert!(!(relative_var_pos == i32::MAX && relative_scope_pos == i32::MAX)); // In certain cases the relative variable position is the same as // the scope position (insertion of binding variables). In that case @@ -1474,26 +1557,76 @@ impl PassValidationLinking { block_stmt.next_unique_id_in_scope = var_counter; } + fn resolve_pending_control_flow_targets(&mut self, ctx: &mut Ctx) -> Result<(), ParseError> { + for entry in &self.control_flow_stmts { + let stmt = &ctx.heap[entry.statement]; + + match stmt { + Statement::Break(stmt) => { + let stmt_id = stmt.this; + let target_while_id = Self::resolve_break_or_continue_target(ctx, entry, stmt.span, &stmt.label)?; + let target_while_stmt = &ctx.heap[target_while_id]; + let target_end_while_id = target_while_stmt.end_while; + debug_assert!(!target_end_while_id.is_invalid()); + + let break_stmt = &mut ctx.heap[stmt_id]; + break_stmt.target = target_end_while_id; + }, + Statement::Continue(stmt) => { + let stmt_id = stmt.this; + let target_while_id = Self::resolve_break_or_continue_target(ctx, entry, stmt.span, &stmt.label)?; + + let continue_stmt = &mut ctx.heap[stmt_id]; + continue_stmt.target = target_while_id; + }, + Statement::Goto(stmt) => { + let stmt_id = stmt.this; + let target_id = Self::find_label(entry.in_scope, ctx, &stmt.label)?; + let target_stmt = &ctx.heap[target_id]; + if entry.in_sync != target_stmt.in_sync { + // Nested sync not allowed. And goto can only go to + // outer scopes, so we must be escaping from a sync. + debug_assert!(target_stmt.in_sync.is_invalid()); // target not in sync + debug_assert!(!entry.in_sync.is_invalid()); // but the goto is in sync + let goto_stmt = &ctx.heap[stmt_id]; + let sync_stmt = &ctx.heap[entry.in_sync]; + return Err( + ParseError::new_error_str_at_span(&ctx.module().source, goto_stmt.span, "goto may not escape the surrounding synchronous block") + .with_info_str_at_span(&ctx.module().source, target_stmt.label.span, "this is the target of the goto statement") + .with_info_str_at_span(&ctx.module().source, sync_stmt.span, "which will jump past this statement") + ); + } + + let goto_stmt = &mut ctx.heap[stmt_id]; + goto_stmt.target = target_id; + }, + _ => unreachable!("cannot resolve control flow target for {:?}", stmt), + } + } + + return Ok(()) + } + //-------------------------------------------------------------------------- // Utilities //-------------------------------------------------------------------------- /// Adds a local variable to the current scope. It will also annotate the /// `Local` in the AST with its relative position in the block. - fn checked_add_local(&mut self, ctx: &mut Ctx, relative_pos: u32, id: VariableId) -> Result<(), ParseError> { - debug_assert!(self.cur_scope.is_block()); + fn checked_add_local(&mut self, ctx: &mut Ctx, target_scope: Scope, target_relative_pos: i32, id: VariableId) -> Result<(), ParseError> { + debug_assert!(target_scope.is_block()); let local = &ctx.heap[id]; - let mut scope = &self.cur_scope; + // We immediately go to the parent scope. We check the target scope + // in the call at the end. That is also where we check for collisions + // with symbols. + let block = &ctx.heap[target_scope.to_block()]; + let mut scope = block.scope_node.parent; + let mut cur_relative_pos = block.scope_node.relative_pos_in_parent; loop { - // We immediately go to the parent scope. We check the current scope - // in the call at the end. Likewise for checking the symbol table. - let block = &ctx.heap[scope.to_block()]; - - scope = &block.scope_node.parent; if let Scope::Definition(definition_id) = scope { // At outer scope, check parameters of function/component - for parameter_id in ctx.heap[*definition_id].parameters() { + for parameter_id in ctx.heap[definition_id].parameters() { let parameter = &ctx.heap[*parameter_id]; if local.identifier == parameter.identifier { return Err( @@ -1511,7 +1644,7 @@ impl PassValidationLinking { } // If here then the parent scope is a block scope - let local_relative_pos = ctx.heap[scope.to_block()].relative_pos_in_parent; + let block = &ctx.heap[scope.to_block()]; for other_local_id in &block.locals { let other_local = &ctx.heap[*other_local_id]; @@ -1519,7 +1652,7 @@ impl PassValidationLinking { // is defined in a higher-level scope, but later than the scope // in which the current variable resides. if local.this != *other_local_id && - local_relative_pos >= other_local.relative_pos_in_block && + cur_relative_pos >= other_local.relative_pos_in_block && local.identifier == other_local.identifier { // Collision within this scope return Err( @@ -1531,17 +1664,20 @@ impl PassValidationLinking { ); } } + + scope = block.scope_node.parent; + cur_relative_pos = block.scope_node.relative_pos_in_parent; } // No collisions in any of the parent scope, attempt to add to scope - self.checked_at_single_scope_add_local(ctx, self.cur_scope, relative_pos, id) + self.checked_at_single_scope_add_local(ctx, target_scope, target_relative_pos, id) } /// Adds a local variable to the specified scope. Will check the specified /// scope for variable conflicts and the symbol table for global conflicts. /// Will NOT check parent scopes of the specified scope. fn checked_at_single_scope_add_local( - &mut self, ctx: &mut Ctx, scope: Scope, relative_pos: u32, id: VariableId + &mut self, ctx: &mut Ctx, scope: Scope, relative_pos: i32, id: VariableId ) -> Result<(), ParseError> { // Check the symbol table for conflicts { @@ -1565,7 +1701,7 @@ impl PassValidationLinking { for other_local_id in &block.locals { let other_local = &ctx.heap[*other_local_id]; if local.this != other_local.this && - relative_pos >= other_local.relative_pos_in_block && + // relative_pos >= other_local.relative_pos_in_block && local.identifier == other_local.identifier { // Collision return Err( @@ -1590,7 +1726,7 @@ impl PassValidationLinking { /// Finds a variable in the visitor's scope that must appear before the /// specified relative position within that block. - fn find_variable(&self, ctx: &Ctx, mut relative_pos: u32, identifier: &Identifier) -> Result { + fn find_variable(&self, ctx: &Ctx, mut relative_pos: i32, identifier: &Identifier) -> Option { debug_assert!(self.cur_scope.is_block()); // No need to use iterator over namespaces if here @@ -1603,8 +1739,8 @@ impl PassValidationLinking { for local_id in &block.locals { let local = &ctx.heap[*local_id]; - if local.relative_pos_in_block <= relative_pos && identifier == &local.identifier { - return Ok(*local_id); + if local.relative_pos_in_block < relative_pos && identifier == &local.identifier { + return Some(*local_id); } } @@ -1617,7 +1753,7 @@ impl PassValidationLinking { for parameter_id in definition.parameters() { let parameter = &ctx.heap[*parameter_id]; if identifier == ¶meter.identifier { - return Ok(*parameter_id); + return Some(*parameter_id); } } }, @@ -1625,16 +1761,16 @@ impl PassValidationLinking { } // Variable could not be found - return Err(()) + return None } else { - relative_pos = block.relative_pos_in_parent; + relative_pos = block.scope_node.relative_pos_in_parent; } } } /// Adds a particular label to the current scope. Will return an error if /// there is another label with the same name visible in the current scope. - fn checked_add_label(&mut self, ctx: &mut Ctx, relative_pos: u32, in_sync: SynchronousStatementId, id: LabeledStatementId) -> Result<(), ParseError> { + fn checked_add_label(&mut self, ctx: &mut Ctx, relative_pos: i32, in_sync: SynchronousStatementId, id: LabeledStatementId) -> Result<(), ParseError> { debug_assert!(self.cur_scope.is_block()); // Make sure label is not defined within the current scope or any of the @@ -1677,13 +1813,12 @@ impl PassValidationLinking { /// Finds a particular labeled statement by its identifier. Once found it /// will make sure that the target label does not skip over any variable /// declarations within the scope in which the label was found. - fn find_label(&self, ctx: &Ctx, identifier: &Identifier) -> Result { - debug_assert!(self.cur_scope.is_block()); + fn find_label(mut scope: Scope, ctx: &Ctx, identifier: &Identifier) -> Result { + debug_assert!(scope.is_block()); - let mut scope = &self.cur_scope; loop { debug_assert!(scope.is_block(), "scope is not a block"); - let relative_scope_pos = ctx.heap[scope.to_block()].relative_pos_in_parent; + let relative_scope_pos = ctx.heap[scope.to_block()].scope_node.relative_pos_in_parent; let block = &ctx.heap[scope.to_block()]; for label_id in &block.labels { @@ -1707,7 +1842,7 @@ impl PassValidationLinking { } } - scope = &block.scope_node.parent; + scope = block.scope_node.parent; if !scope.is_block() { return Err(ParseError::new_error_str_at_span( &ctx.module().source, identifier.span, "could not find this label" @@ -1719,8 +1854,7 @@ impl PassValidationLinking { /// This function will check if the provided while statement ID has a block /// statement that is one of our current parents. - fn has_parent_while_scope(&self, ctx: &Ctx, id: WhileStatementId) -> bool { - let mut scope = &self.cur_scope; + fn has_parent_while_scope(mut scope: Scope, ctx: &Ctx, id: WhileStatementId) -> bool { let while_stmt = &ctx.heap[id]; loop { debug_assert!(scope.is_block()); @@ -1730,7 +1864,7 @@ impl PassValidationLinking { } let block = &ctx.heap[block]; - scope = &block.scope_node.parent; + scope = block.scope_node.parent; if !scope.is_block() { return false; } @@ -1743,17 +1877,17 @@ impl PassValidationLinking { /// ID will be returned, otherwise a parsing error is constructed. /// The provided input position should be the position of the break/continue /// statement. - fn resolve_break_or_continue_target(&self, ctx: &Ctx, span: InputSpan, label: &Option) -> Result { + fn resolve_break_or_continue_target(ctx: &Ctx, control_flow: &ControlFlowStatement, span: InputSpan, label: &Option) -> Result { let target = match label { Some(label) => { - let target_id = self.find_label(ctx, label)?; + let target_id = Self::find_label(control_flow.in_scope, ctx, label)?; // Make sure break target is a while statement let target = &ctx.heap[target_id]; if let Statement::While(target_stmt) = &ctx.heap[target.body] { // Even though we have a target while statement, the break might not be // present underneath this particular labeled while statement - if !self.has_parent_while_scope(ctx, target_stmt.this) { + if !Self::has_parent_while_scope(control_flow.in_scope, ctx, target_stmt.this) { return Err(ParseError::new_error_str_at_span( &ctx.module().source, label.span, "break statement is not nested under the target label's while statement" ).with_info_str_at_span( @@ -1773,13 +1907,13 @@ impl PassValidationLinking { None => { // Use the enclosing while statement, the break must be // nested within that while statement - if self.in_while.is_invalid() { + if control_flow.in_while.is_invalid() { return Err(ParseError::new_error_str_at_span( &ctx.module().source, span, "Break statement is not nested under a while loop" )); } - self.in_while + control_flow.in_while } }; @@ -1787,11 +1921,11 @@ impl PassValidationLinking { // make sure we will not break out of a synchronous block { let target_while = &ctx.heap[target]; - if target_while.in_sync != self.in_sync { + if target_while.in_sync != control_flow.in_sync { // Break is nested under while statement, so can only escape a // sync block if the sync is nested inside the while statement. - debug_assert!(!self.in_sync.is_invalid()); - let sync_stmt = &ctx.heap[self.in_sync]; + debug_assert!(!control_flow.in_sync.is_invalid()); + let sync_stmt = &ctx.heap[control_flow.in_sync]; return Err( ParseError::new_error_str_at_span(&ctx.module().source, span, "break may not escape the surrounding synchronous block") .with_info_str_at_span(&ctx.module().source, target_while.span, "the break escapes out of this loop") diff --git a/src/protocol/parser/token_parsing.rs b/src/protocol/parser/token_parsing.rs index 3e265270f8f0375198d1a606c0cdf7c51fe4bcce..47f02e7f6c2367f467a348b4363af2e28a590014 100644 --- a/src/protocol/parser/token_parsing.rs +++ b/src/protocol/parser/token_parsing.rs @@ -47,6 +47,7 @@ pub(crate) const KW_STMT_GOTO: &'static [u8] = b"goto"; pub(crate) const KW_STMT_RETURN: &'static [u8] = b"return"; pub(crate) const KW_STMT_SYNC: &'static [u8] = b"sync"; pub(crate) const KW_STMT_FORK: &'static [u8] = b"fork"; +pub(crate) const KW_STMT_SELECT: &'static [u8] = b"select"; pub(crate) const KW_STMT_OR: &'static [u8] = b"or"; pub(crate) const KW_STMT_NEW: &'static [u8] = b"new"; diff --git a/src/protocol/parser/visitor.rs b/src/protocol/parser/visitor.rs index edd9bf5e3989728d794ca63b514600e7b50953f8..3d62f1c603b4b29b8d640dee5b355d8186fc3383 100644 --- a/src/protocol/parser/visitor.rs +++ b/src/protocol/parser/visitor.rs @@ -133,6 +133,11 @@ pub(crate) trait Visitor { self.visit_fork_stmt(ctx, this) }, Statement::EndFork(_stmt) => Ok(()), + Statement::Select(stmt) => { + let this = stmt.this; + self.visit_select_stmt(ctx, this) + }, + Statement::EndSelect(_stmt) => Ok(()), Statement::Return(stmt) => { let this = stmt.this; self.visit_return_stmt(ctx, this) @@ -176,6 +181,7 @@ pub(crate) trait Visitor { fn visit_continue_stmt(&mut self, _ctx: &mut Ctx, _id: ContinueStatementId) -> VisitorResult { Ok(()) } fn visit_synchronous_stmt(&mut self, _ctx: &mut Ctx, _id: SynchronousStatementId) -> VisitorResult { Ok(()) } fn visit_fork_stmt(&mut self, _ctx: &mut Ctx, _id: ForkStatementId) -> VisitorResult { Ok(()) } + fn visit_select_stmt(&mut self, _ctx: &mut Ctx, _id: SelectStatementId) -> VisitorResult { Ok(()) } fn visit_return_stmt(&mut self, _ctx: &mut Ctx, _id: ReturnStatementId) -> VisitorResult { Ok(()) } fn visit_goto_stmt(&mut self, _ctx: &mut Ctx, _id: GotoStatementId) -> VisitorResult { Ok(()) } fn visit_new_stmt(&mut self, _ctx: &mut Ctx, _id: NewStatementId) -> VisitorResult { Ok(()) } diff --git a/src/protocol/tests/lexer.rs b/src/protocol/tests/lexer.rs deleted file mode 100644 index 2c7f4ef5d81bdcba5631a065803ca7cd1d10ca7c..0000000000000000000000000000000000000000 --- a/src/protocol/tests/lexer.rs +++ /dev/null @@ -1,109 +0,0 @@ -/// lexer.rs -/// -/// Simple tests for the lexer. Only tests the lexing of the input source and -/// the resulting AST without relying on the validation/typing pass - -use super::*; - -#[test] -fn test_disallowed_inference() { - Tester::new_single_source_expect_err( - "argument auto inference", - "func thing(auto arg) -> s32 { return 0; }" - ).error(|e| { e - .assert_msg_has(0, "inference is not allowed") - .assert_occurs_at(0, "auto arg"); - }); - - Tester::new_single_source_expect_err( - "return type auto inference", - "func thing(s32 arg) -> auto { return 0; }" - ).error(|e| { e - .assert_msg_has(0, "inference is not allowed") - .assert_occurs_at(0, "auto {"); - }); - - Tester::new_single_source_expect_err( - "implicit polymorph argument auto inference", - "func thing(in port) -> s32 { return port; }" - ).error(|e| { e - .assert_msg_has(0, "inference is not allowed") - .assert_occurs_at(0, "in port"); - }); - - Tester::new_single_source_expect_err( - "explicit polymorph argument auto inference", - "func thing(in port) -> s32 { return port; }" - ).error(|e| { e - .assert_msg_has(0, "inference is not allowed") - .assert_occurs_at(0, "auto> port"); - }); - - Tester::new_single_source_expect_err( - "implicit polymorph return type auto inference", - "func thing(in a, in b) -> in { return a; }" - ).error(|e| { e - .assert_msg_has(0, "inference is not allowed") - .assert_occurs_at(0, "in {"); - }); - - Tester::new_single_source_expect_err( - "explicit polymorph return type auto inference", - "func thing(in a) -> in { return a; }" - ).error(|e| { e - .assert_msg_has(0, "inference is not allowed") - .assert_occurs_at(0, "auto> {"); - }); -} - -#[test] -fn test_simple_struct_definition() { - Tester::new_single_source_expect_ok( - "empty struct", - "struct Foo{}" - ).for_struct("Foo", |t| { t.assert_num_fields(0); }); - - Tester::new_single_source_expect_ok( - "single field, no comma", - "struct Foo{ s32 field }" - ).for_struct("Foo", |t| { t - .assert_num_fields(1) - .for_field("field", |f| { - f.assert_parser_type("s32"); - }); - }); - - Tester::new_single_source_expect_ok( - "single field, with comma", - "struct Foo{ s32 field, }" - ).for_struct("Foo", |t| { t - .assert_num_fields(1) - .for_field("field", |f| { f - .assert_parser_type("s32"); - }); - }); - - Tester::new_single_source_expect_ok( - "multiple fields, no comma", - "struct Foo{ u8 a, s16 b, s32 c }" - ).for_struct("Foo", |t| { t - .assert_num_fields(3) - .for_field("a", |f| { f.assert_parser_type("u8"); }) - .for_field("b", |f| { f.assert_parser_type("s16"); }) - .for_field("c", |f| { f.assert_parser_type("s32"); }); - }); - - Tester::new_single_source_expect_ok( - "multiple fields, with comma", - "struct Foo{ - u8 a, - s16 b, - s32 c, - }" - ).for_struct("Foo", |t| { t - .assert_num_fields(3) - .for_field("a", |f| { f.assert_parser_type("u8"); }) - .for_field("b", |f| { f.assert_parser_type("s16"); }) - .for_field("c", |f| { f.assert_parser_type("s32"); }); - }); -} \ No newline at end of file diff --git a/src/protocol/tests/mod.rs b/src/protocol/tests/mod.rs index 70eb7af70543577dd8a83fc315e8ce7a29221c1e..0d2f9cfe3d391cc37a4d19dd7c29e0edd1fa786a 100644 --- a/src/protocol/tests/mod.rs +++ b/src/protocol/tests/mod.rs @@ -12,13 +12,13 @@ */ mod utils; -mod lexer; mod parser_binding; mod parser_imports; mod parser_inference; mod parser_literals; mod parser_monomorphs; -mod parser_types; +mod parser_type_declaration; +mod parser_type_layout; mod parser_validation; mod eval_binding; mod eval_calls; diff --git a/src/protocol/tests/parser_binding.rs b/src/protocol/tests/parser_binding.rs index 4aa276e9371a3b41772ecf8e7feb2b99b8c4b935..b097f3e748a6d077357c992b4bcb58d93461783a 100644 --- a/src/protocol/tests/parser_binding.rs +++ b/src/protocol/tests/parser_binding.rs @@ -60,6 +60,48 @@ fn test_incorrect_binding() { }); } +#[test] +fn test_incorrect_binding_variable() { + // Note: if variable already exists then it is interpreted at the binding + // expression as value. So the case where "the variable is already defined" + // results in no binding variable. + + Tester::new_single_source_expect_err("binding var in next scope", " + union Option{ Some(T), None } + func foo() -> bool { + auto opt = Option::Some(false); + if (let Option::Some(var) = opt) { + auto var = true; // should mismatch against binding 'var' + return var; + } + return false; + } + ").error(|e| { e + .assert_num(2) + .assert_msg_has(0, "variable name conflicts") + .assert_occurs_at(0, "var = true;") + .assert_occurs_at(1, "var) = opt"); + }); + + Tester::new_single_source_expect_err("binding var in nested scope", " + union LR{ L(A), R(B) } + func foo() -> u32 { + LR x = LR::L(5); + if (let LR::L(y) = x) { + if (true) { + auto y = 5; + return y; + } + } + } + ").error(|e| { e + .assert_num(2) + .assert_msg_has(0, "variable name conflicts") + .assert_occurs_at(0, "y = 5") + .assert_occurs_at(1, "y) = x"); + }); +} + #[test] fn test_boolean_ops_on_binding() { Tester::new_single_source_expect_ok("apply && to binding result", " diff --git a/src/protocol/tests/parser_inference.rs b/src/protocol/tests/parser_inference.rs index 44dcd85a51a0797fed6e79151940b2c57e4bd322..282c19626d328bee46016db1bfc4bd816a387c0c 100644 --- a/src/protocol/tests/parser_inference.rs +++ b/src/protocol/tests/parser_inference.rs @@ -468,6 +468,56 @@ fn test_failed_polymorph_inference() { ); } +#[test] +fn test_disallowed_inference() { + Tester::new_single_source_expect_err( + "argument auto inference", + "func thing(auto arg) -> s32 { return 0; }" + ).error(|e| { e + .assert_msg_has(0, "inference is not allowed") + .assert_occurs_at(0, "auto arg"); + }); + + Tester::new_single_source_expect_err( + "return type auto inference", + "func thing(s32 arg) -> auto { return 0; }" + ).error(|e| { e + .assert_msg_has(0, "inference is not allowed") + .assert_occurs_at(0, "auto {"); + }); + + Tester::new_single_source_expect_err( + "implicit polymorph argument auto inference", + "func thing(in port) -> s32 { return port; }" + ).error(|e| { e + .assert_msg_has(0, "inference is not allowed") + .assert_occurs_at(0, "in port"); + }); + + Tester::new_single_source_expect_err( + "explicit polymorph argument auto inference", + "func thing(in port) -> s32 { return port; }" + ).error(|e| { e + .assert_msg_has(0, "inference is not allowed") + .assert_occurs_at(0, "auto> port"); + }); + + Tester::new_single_source_expect_err( + "implicit polymorph return type auto inference", + "func thing(in a, in b) -> in { return a; }" + ).error(|e| { e + .assert_msg_has(0, "inference is not allowed") + .assert_occurs_at(0, "in {"); + }); + + Tester::new_single_source_expect_err( + "explicit polymorph return type auto inference", + "func thing(in a) -> in { return a; }" + ).error(|e| { e + .assert_msg_has(0, "inference is not allowed") + .assert_occurs_at(0, "auto> {"); + }); +} #[test] fn test_explicit_polymorph_argument() { diff --git a/src/protocol/tests/parser_type_declaration.rs b/src/protocol/tests/parser_type_declaration.rs new file mode 100644 index 0000000000000000000000000000000000000000..b1107cda6e426b7a2c754e1282b7c922b2b6271b --- /dev/null +++ b/src/protocol/tests/parser_type_declaration.rs @@ -0,0 +1,53 @@ +use super::*; + +#[test] +fn test_simple_struct_definition() { + Tester::new_single_source_expect_ok( + "empty struct", + "struct Foo{}" + ).for_struct("Foo", |t| { t.assert_num_fields(0); }); + + Tester::new_single_source_expect_ok( + "single field, no comma", + "struct Foo{ s32 field }" + ).for_struct("Foo", |t| { t + .assert_num_fields(1) + .for_field("field", |f| { + f.assert_parser_type("s32"); + }); + }); + + Tester::new_single_source_expect_ok( + "single field, with comma", + "struct Foo{ s32 field, }" + ).for_struct("Foo", |t| { t + .assert_num_fields(1) + .for_field("field", |f| { f + .assert_parser_type("s32"); + }); + }); + + Tester::new_single_source_expect_ok( + "multiple fields, no comma", + "struct Foo{ u8 a, s16 b, s32 c }" + ).for_struct("Foo", |t| { t + .assert_num_fields(3) + .for_field("a", |f| { f.assert_parser_type("u8"); }) + .for_field("b", |f| { f.assert_parser_type("s16"); }) + .for_field("c", |f| { f.assert_parser_type("s32"); }); + }); + + Tester::new_single_source_expect_ok( + "multiple fields, with comma", + "struct Foo{ + u8 a, + s16 b, + s32 c, + }" + ).for_struct("Foo", |t| { t + .assert_num_fields(3) + .for_field("a", |f| { f.assert_parser_type("u8"); }) + .for_field("b", |f| { f.assert_parser_type("s16"); }) + .for_field("c", |f| { f.assert_parser_type("s32"); }); + }); +} \ No newline at end of file diff --git a/src/protocol/tests/parser_types.rs b/src/protocol/tests/parser_type_layout.rs similarity index 100% rename from src/protocol/tests/parser_types.rs rename to src/protocol/tests/parser_type_layout.rs diff --git a/src/protocol/tests/parser_validation.rs b/src/protocol/tests/parser_validation.rs index 0462c7e1e7f81070d022ab9fa36ad21f66d6203c..f6cb386849c640aa79f36a3f3d96477e9161917b 100644 --- a/src/protocol/tests/parser_validation.rs +++ b/src/protocol/tests/parser_validation.rs @@ -4,6 +4,8 @@ use super::*; + + #[test] fn test_correct_struct_instance() { Tester::new_single_source_expect_ok( @@ -528,4 +530,254 @@ fn test_polymorph_array_types() { ).for_struct("Bar", |s| { s .for_field("inputs", |f| { f.assert_parser_type("in[]"); }); }); +} + +#[test] +fn test_correct_modifying_operators() { + // Not testing the types, just that it parses + Tester::new_single_source_expect_ok( + "valid uses", + " + func f() -> u32 { + auto a = 5; + a += 2; a -= 2; a *= 2; a /= 2; a %= 2; + a <<= 2; a >>= 2; + a |= 2; a &= 2; a ^= 2; + return a; + } + " + ); +} + +#[test] +fn test_incorrect_modifying_operators() { + Tester::new_single_source_expect_err( + "wrong declaration", + "func f() -> u8 { auto a += 2; return a; }" + ).error(|e| { e.assert_msg_has(0, "expected '='"); }); + + Tester::new_single_source_expect_err( + "inside function", + "func f(u32 a) -> u32 { auto b = 0; auto c = f(a += 2); }" + ).error(|e| { e.assert_msg_has(0, "assignments are statements"); }); + + Tester::new_single_source_expect_err( + "inside tuple", + "func f(u32 a) -> u32 { auto b = (a += 2, a /= 2); return 0; }" + ).error(|e| { e.assert_msg_has(0, "assignments are statements"); }); +} + +#[test] +fn test_variable_introduction_in_scope() { + Tester::new_single_source_expect_err( + "variable use before declaration", + "func f() -> u8 { return thing; auto thing = 5; }" + ).error(|e| { e.assert_msg_has(0, "unresolved variable"); }); + + Tester::new_single_source_expect_err( + "variable use in declaration", + "func f() -> u8 { auto thing = 5 + thing; return thing; }" + ).error(|e| { e.assert_msg_has(0, "unresolved variable"); }); + + Tester::new_single_source_expect_ok( + "variable use after declaration", + "func f() -> u8 { auto thing = 5; return thing; }" + ); + + Tester::new_single_source_expect_err( + "variable use of closed scope", + "func f() -> u8 { { auto thing = 5; } return thing; }" + ).error(|e| { e.assert_msg_has(0, "unresolved variable"); }); +} + +#[test] +fn test_correct_select_statement() { + + Tester::new_single_source_expect_ok( + "guard variable decl", + " + primitive f() { + channel unused -> input; + + u32 outer_value = 0; + sync select { + auto in_same_guard = get(input) -> {} // decl A1 + auto in_same_gaurd = get(input) -> {} // decl A2 + auto in_guard_and_block = get(input) -> {} // decl B1 + outer_value = get(input) -> { auto in_guard_and_block = outer_value; } // decl B2 + } + } + " + ); + + Tester::new_single_source_expect_ok( + "empty select", + "primitive f() { sync select {} }" + ); + + Tester::new_single_source_expect_ok( + "mixed uses", " + primitive f() { + channel unused_output -> input; + u32 outer_value = 0; + sync select { + outer_value = get(input) -> outer_value = 0; + auto new_value = get(input) -> { + outer_value = new_value; + } + get(input) + get(input) -> + outer_value = 8; + get(input) -> + {} + outer_value %= get(input) -> { + outer_value *= outer_value; + auto new_value = get(input); + outer_value += new_value; + } + } + } + " + ); +} + +#[test] +fn test_incorrect_select_statement() { + Tester::new_single_source_expect_err( + "outside sync", + "primitive f() { select {} }" + ).error(|e| { e + .assert_num(1) + .assert_occurs_at(0, "select") + .assert_msg_has(0, "inside sync blocks"); + }); + + Tester::new_single_source_expect_err( + "variable in previous block", + "primitive f() { + channel tx -> rx; + u32 a = 0; // this one will be shadowed + sync select { auto a = get(rx) -> {} } + }" + ).error(|e| { e + .assert_num(2) + .assert_occurs_at(0, "a = get").assert_msg_has(0, "variable name conflicts") + .assert_occurs_at(1, "a = 0").assert_msg_has(1, "Previous variable"); + }); + + Tester::new_single_source_expect_err( + "put inside arm", + "primitive f() { + channel a -> b; + sync select { put(a) -> {} } + }" + ).error(|e| { e + .assert_occurs_at(0, "put") + .assert_msg_has(0, "may not occur"); + }); +} + +#[test] +fn test_incorrect_goto_statement() { + Tester::new_single_source_expect_err( + "goto missing var in same scope", + "func f() -> u32 { + goto exit; + auto v = 5; + exit: return 0; + }" + ).error(|e| { e + .assert_num(3) + .assert_occurs_at(0, "exit;").assert_msg_has(0, "skips over a variable") + .assert_occurs_at(1, "exit:").assert_msg_has(1, "jumps to this label") + .assert_occurs_at(2, "v = 5").assert_msg_has(2, "skips over this variable"); + }); + + Tester::new_single_source_expect_err( + "goto missing var in outer scope", + "func f() -> u32 { + if (true) { + goto exit; + } + auto v = 0; + exit: return 1; + }" + ).error(|e| { e + .assert_num(3) + .assert_occurs_at(0, "exit;").assert_msg_has(0, "skips over a variable") + .assert_occurs_at(1, "exit:").assert_msg_has(1, "jumps to this label") + .assert_occurs_at(2, "v = 0").assert_msg_has(2, "skips over this variable"); + }); + + Tester::new_single_source_expect_err( + "goto jumping into scope", + "func f() -> u32 { + goto nested; + { + nested: return 0; + } + return 1; + }" + ).error(|e| { e + .assert_num(1) + .assert_occurs_at(0, "nested;") + .assert_msg_has(0, "could not find this label"); + }); + + Tester::new_single_source_expect_err( + "goto jumping outside sync", + "primitive f() { + sync { goto exit; } + exit: u32 v = 0; + }" + ).error(|e| { e + .assert_num(3) + .assert_occurs_at(0, "goto exit;").assert_msg_has(0, "not escape the surrounding sync") + .assert_occurs_at(1, "exit: u32 v").assert_msg_has(1, "target of the goto") + .assert_occurs_at(2, "sync {").assert_msg_has(2, "jump past this"); + }) +} + +#[test] +fn test_incorrect_while_statement() { + // Just testing the error cases caught at compile-time. Other ones need + // evaluation testing + Tester::new_single_source_expect_err( + "break wrong earlier loop", + "func f() -> u32 { + target: while (true) {} + while (true) { break target; } + return 0; + }" + ).error(|e| { e + .assert_num(2) + .assert_occurs_at(0, "target; }").assert_msg_has(0, "not nested under the target") + .assert_occurs_at(1, "target: while").assert_msg_has(1, "is found here"); + }); + + Tester::new_single_source_expect_err( + "break wrong later loop", + "func f() -> u32 { + while (true) { break target; } + target: while (true) {} + return 0; + }" + ).error(|e| { e + .assert_num(2) + .assert_occurs_at(0, "target; }").assert_msg_has(0, "not nested under the target") + .assert_occurs_at(1, "target: while").assert_msg_has(1, "is found here"); + }); + + Tester::new_single_source_expect_err( + "break outside of sync", + "primitive f() { + outer: while (true) { //mark + sync while(true) { break outer; } + } + }" + ).error(|e| { e + .assert_num(3) + .assert_occurs_at(0, "break outer;").assert_msg_has(0, "may not escape the surrounding") + .assert_occurs_at(1, "while (true) { //mark").assert_msg_has(1, "escapes out of this loop") + .assert_occurs_at(2, "sync while").assert_msg_has(2, "escape this synchronous block"); + }); } \ No newline at end of file diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index fa7153214d697f2f57231e79511d844c581a5442..63ff0f95929262c11c44f701f0f31015589658dd 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -535,6 +535,7 @@ impl<'a> FunctionTester<'a> { let mut found_local_id = None; if let Some(block_id) = wrapping_block_id { + // Found the right block, find the variable inside the block again let block_stmt = self.ctx.heap[block_id].as_block(); for local_id in &block_stmt.locals { let var = &self.ctx.heap[*local_id]; @@ -1194,6 +1195,12 @@ fn seek_expr_in_stmt bool>(heap: &Heap, start: StatementId let stmt = &heap[start]; match stmt { + Statement::Local(stmt) => { + match stmt { + LocalStatement::Memory(stmt) => seek_expr_in_expr(heap, stmt.initial_expr.upcast(), f), + LocalStatement::Channel(_) => None + } + } Statement::Block(stmt) => { for stmt_id in &stmt.statements { if let Some(id) = seek_expr_in_stmt(heap, *stmt_id, f) {