From f39f9c5e5873ab09d6ec3c062637e2901a74a691 2022-02-10 09:35:28 From: mh Date: 2022-02-10 09:35:28 Subject: [PATCH] WIP: Refactored scopes in AST, pending bugfixes --- diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index ca6efe6af2637d61f1d1d50fa3cd87285916d4d0..6b2eaef1f425c59ef163f966967fb4a41daff97b 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -158,6 +158,8 @@ define_new_ast_id!(CastExpressionId, ExpressionId, index(CastExpression, Express define_new_ast_id!(CallExpressionId, ExpressionId, index(CallExpression, Expression::Call, expressions), alloc(alloc_call_expression)); define_new_ast_id!(VariableExpressionId, ExpressionId, index(VariableExpression, Expression::Variable, expressions), alloc(alloc_variable_expression)); +define_aliased_ast_id!(ScopeId, Id, index(Scope, scopes), alloc(alloc_scope)); + #[derive(Debug)] pub struct Heap { // Root arena, contains the entry point for different modules. Each root @@ -170,6 +172,7 @@ pub struct Heap { pub(crate) definitions: Arena, pub(crate) statements: Arena, pub(crate) expressions: Arena, + pub(crate) scopes: Arena, } impl Heap { @@ -183,6 +186,7 @@ impl Heap { definitions: Arena::new(), statements: Arena::new(), expressions: Arena::new(), + scopes: Arena::new(), } } pub fn alloc_memory_statement( @@ -705,57 +709,47 @@ impl<'a> Iterator for ConcreteTypeIter<'a> { } #[derive(Debug, Clone, Copy)] -pub enum Scope { +pub enum ScopeAssociation { Definition(DefinitionId), - Regular(BlockStatementId), - Synchronous(SynchronousStatementId, BlockStatementId), -} - -impl Scope { - pub(crate) fn new_invalid() -> Scope { - return Scope::Definition(DefinitionId::new_invalid()); - } - - pub(crate) fn is_invalid(&self) -> bool { - match self { - Scope::Definition(id) => id.is_invalid(), - _ => false, - } - } - - pub fn is_block(&self) -> bool { - match &self { - Scope::Definition(_) => false, - Scope::Regular(_) => true, - Scope::Synchronous(_, _) => true, - } - } - pub fn to_block(&self) -> 
BlockStatementId { - match &self { - Scope::Regular(id) => *id, - Scope::Synchronous(_, id) => *id, - _ => panic!("unable to get BlockStatement from Scope") - } - } + Block(BlockStatementId), + If(IfStatementId, bool), // if true, then body of "if", otherwise body of "else" + While(WhileStatementId), + Synchronous(SynchronousStatementId), + SelectCase(SelectStatementId, u32), // index is select case } /// `ScopeNode` is a helper that links scopes in two directions. It doesn't /// actually contain any information associated with the scope, this may be /// found on the AST elements that `Scope` points to. #[derive(Debug, Clone)] -pub struct ScopeNode { - pub parent: Scope, - pub nested: Vec, +pub struct Scope { + // Relation to other scopes + pub this: ScopeId, + pub parent: Option, + pub nested: Vec, + // Locally available variables/labels + pub association: ScopeAssociation, + pub variables: Vec, + pub labels: Vec, + // Location trackers/counters pub relative_pos_in_parent: i32, + pub first_unique_id_in_scope: i32, + pub next_unique_id_in_scope: i32, } -impl ScopeNode { - pub(crate) fn new_invalid() -> Self { - ScopeNode{ - parent: Scope::new_invalid(), +impl Scope { + pub(crate) fn new_invalid(this: ScopeId) -> Self { + return Self{ + this, + parent: None, nested: Vec::new(), + association: ScopeAssociation::Definition(DefinitionId::new_invalid()), + variables: Vec::new(), + labels: Vec::new(), relative_pos_in_parent: -1, - } + first_unique_id_in_scope: -1, + next_unique_id_in_scope: -1, + }; } } @@ -774,7 +768,7 @@ pub struct Variable { pub parser_type: ParserType, pub identifier: Identifier, // Validator/linker - pub relative_pos_in_block: i32, + pub relative_pos_in_parent: i32, pub unique_id_in_scope: i32, // Temporary fix until proper bytecode/asm is generated } @@ -1020,6 +1014,7 @@ pub struct ComponentDefinition { pub poly_vars: Vec, // Parsing pub parameters: Vec, + pub scope: ScopeId, pub body: BlockStatementId, // Validation/linking pub 
num_expressions_in_body: i32, @@ -1033,7 +1028,8 @@ impl ComponentDefinition { ) -> Self { Self{ this, defined_in, span, variant, identifier, poly_vars, - parameters: Vec::new(), + parameters: Vec::new(), + scope: ScopeId::new_invalid(), body: BlockStatementId::new_invalid(), num_expressions_in_body: -1, } @@ -1052,8 +1048,9 @@ pub struct FunctionDefinition { pub identifier: Identifier, pub poly_vars: Vec, // Parser - pub return_types: Vec, + pub return_type: ParserType, pub parameters: Vec, + pub scope: ScopeId, pub body: BlockStatementId, // Validation/linking pub num_expressions_in_body: i32, @@ -1068,8 +1065,9 @@ impl FunctionDefinition { this, defined_in, builtin: false, span, identifier, poly_vars, - return_types: Vec::new(), + return_type: ParserType{ elements: Vec::new(), full_span: InputSpan::new() }, parameters: Vec::new(), + scope: ScopeId::new_invalid(), body: BlockStatementId::new_invalid(), num_expressions_in_body: -1, } @@ -1178,6 +1176,7 @@ impl Statement { | Statement::If(_) => unreachable!(), } } + } #[derive(Debug, Clone)] @@ -1189,11 +1188,7 @@ pub struct BlockStatement { pub statements: Vec, pub end_block: EndBlockStatementId, // Phase 2: linker - pub scope_node: ScopeNode, - pub first_unique_id_in_scope: i32, // Temporary fix until proper bytecode/asm is generated - pub next_unique_id_in_scope: i32, // Temporary fix until proper bytecode/asm is generated - pub locals: Vec, - pub labels: Vec, + pub scope: ScopeId, pub next: StatementId, } @@ -1263,7 +1258,7 @@ pub struct ChannelStatement { pub from: VariableId, // output pub to: VariableId, // input // Phase 2: linker - pub relative_pos_in_block: i32, + pub relative_pos_in_parent: i32, pub next: StatementId, } @@ -1274,7 +1269,7 @@ pub struct LabeledStatement { pub label: Identifier, pub body: StatementId, // Phase 2: linker - pub relative_pos_in_block: i32, + pub relative_pos_in_parent: i32, pub in_sync: SynchronousStatementId, // may be invalid } @@ -1284,11 +1279,17 @@ pub struct IfStatement 
{ // Phase 1: parser pub span: InputSpan, // of the "if" keyword pub test: ExpressionId, - pub true_body: BlockStatementId, - pub false_body: Option, + pub true_case: IfStatementCase, + pub false_case: Option, pub end_if: EndIfStatementId, } +#[derive(Debug, Clone, Copy)] +pub struct IfStatementCase { + pub body: StatementId, + pub scope: ScopeId, +} + #[derive(Debug, Clone)] pub struct EndIfStatement { pub this: EndIfStatementId, @@ -1302,7 +1303,8 @@ pub struct WhileStatement { // Phase 1: parser pub span: InputSpan, // of the "while" keyword pub test: ExpressionId, - pub body: BlockStatementId, + pub scope: ScopeId, + pub body: StatementId, pub end_while: EndWhileStatementId, pub in_sync: SynchronousStatementId, // may be invalid } @@ -1340,7 +1342,8 @@ pub struct SynchronousStatement { pub this: SynchronousStatementId, // Phase 1: parser pub span: InputSpan, // of the "sync" keyword - pub body: BlockStatementId, + pub scope: ScopeId, + pub body: StatementId, pub end_sync: EndSynchronousStatementId, } @@ -1357,8 +1360,8 @@ pub struct ForkStatement { pub this: ForkStatementId, // Phase 1: parser pub span: InputSpan, // of the "fork" keyword - pub left_body: BlockStatementId, - pub right_body: Option, + pub left_body: StatementId, + pub right_body: Option, pub end_fork: EndForkStatementId, } @@ -1382,7 +1385,8 @@ pub struct SelectCase { // The guard statement of a `select` is either a MemoryStatement or an // ExpressionStatement. 
Nothing else is allowed by the initial parsing pub guard: StatementId, - pub block: BlockStatementId, + pub body: StatementId, + pub scope: ScopeId, // Phase 2: Validation and Linking pub involved_ports: Vec<(CallExpressionId, ExpressionId)>, // call to `get` and its port argument } diff --git a/src/protocol/ast_printer.rs b/src/protocol/ast_printer.rs index 5529cd3dde91d3b9834ab3cc4ecce3d8641eef9d..ef616a9504825f992232a4087d7898e01352722f 100644 --- a/src/protocol/ast_printer.rs +++ b/src/protocol/ast_printer.rs @@ -354,11 +354,8 @@ impl ASTWriter { self.kv(indent3).with_s_key("PolyVar").with_identifier_val(&poly_var_id); } - self.kv(indent2).with_s_key("ReturnParserTypes"); - for return_type in &def.return_types { - self.kv(indent3).with_s_key("ReturnParserType") - .with_custom_val(|s| write_parser_type(s, heap, return_type)); - } + self.kv(indent2).with_s_key("ReturnParserType") + .with_custom_val(|s| write_parser_type(s, heap, &def.return_type)); self.kv(indent2).with_s_key("Parameters"); for variable_id in &def.parameters { @@ -401,9 +398,7 @@ impl ASTWriter { self.kv(indent).with_id(PREFIX_BLOCK_STMT_ID, stmt.this.0.index) .with_s_key("Block"); self.kv(indent2).with_s_key("EndBlockID").with_disp_val(&stmt.end_block.0.index); - self.kv(indent2).with_s_key("FirstUniqueScopeID").with_disp_val(&stmt.first_unique_id_in_scope); - self.kv(indent2).with_s_key("NextUniqueScopeID").with_disp_val(&stmt.next_unique_id_in_scope); - self.kv(indent2).with_s_key("RelativePos").with_disp_val(&stmt.scope_node.relative_pos_in_parent); + self.kv(indent2).with_s_key("ScopeID").with_disp_val(&stmt.scope.index); self.kv(indent2).with_s_key("Statements"); for stmt_id in &stmt.statements { @@ -457,11 +452,11 @@ impl ASTWriter { self.write_expr(heap, stmt.test, indent3); self.kv(indent2).with_s_key("TrueBody"); - self.write_stmt(heap, stmt.true_body.upcast(), indent3); + self.write_stmt(heap, stmt.true_case.body, indent3); - if let Some(false_body) = stmt.false_body { + if let 
Some(false_body) = stmt.false_case { self.kv(indent2).with_s_key("FalseBody"); - self.write_stmt(heap, false_body.upcast(), indent3); + self.write_stmt(heap, false_body.body, indent3); } }, Statement::EndIf(stmt) => { @@ -480,7 +475,7 @@ impl ASTWriter { self.kv(indent2).with_s_key("Condition"); self.write_expr(heap, stmt.test, indent3); self.kv(indent2).with_s_key("Body"); - self.write_stmt(heap, stmt.body.upcast(), indent3); + self.write_stmt(heap, stmt.body, indent3); }, Statement::EndWhile(stmt) => { self.kv(indent).with_id(PREFIX_ENDWHILE_STMT_ID, stmt.this.0.index) @@ -509,7 +504,7 @@ impl ASTWriter { .with_s_key("Synchronous"); self.kv(indent2).with_s_key("EndSync").with_disp_val(&stmt.end_sync.0.index); self.kv(indent2).with_s_key("Body"); - self.write_stmt(heap, stmt.body.upcast(), indent3); + self.write_stmt(heap, stmt.body, indent3); }, Statement::EndSynchronous(stmt) => { self.kv(indent).with_id(PREFIX_ENDSYNC_STMT_ID, stmt.this.0.index) @@ -522,11 +517,11 @@ impl ASTWriter { .with_s_key("Fork"); self.kv(indent2).with_s_key("EndFork").with_disp_val(&stmt.end_fork.0.index); self.kv(indent2).with_s_key("LeftBody"); - self.write_stmt(heap, stmt.left_body.upcast(), indent3); + self.write_stmt(heap, stmt.left_body, indent3); if let Some(right_body_id) = stmt.right_body { self.kv(indent2).with_s_key("RightBody"); - self.write_stmt(heap, right_body_id.upcast(), indent3); + self.write_stmt(heap, right_body_id, indent3); } }, Statement::EndFork(stmt) => { @@ -547,7 +542,7 @@ impl ASTWriter { self.write_stmt(heap, case.guard, indent4); self.kv(indent3).with_s_key("Block"); - self.write_stmt(heap, case.block.upcast(), indent4); + self.write_stmt(heap, case.body, indent4); } }, Statement::EndSelect(stmt) => { @@ -825,7 +820,7 @@ impl ASTWriter { self.kv(indent2).with_s_key("Kind").with_debug_val(&var.kind); self.kv(indent2).with_s_key("ParserType") .with_custom_val(|w| write_parser_type(w, heap, &var.parser_type)); - 
self.kv(indent2).with_s_key("RelativePos").with_disp_val(&var.relative_pos_in_block); + self.kv(indent2).with_s_key("RelativePos").with_disp_val(&var.relative_pos_in_parent); self.kv(indent2).with_s_key("UniqueScopeID").with_disp_val(&var.unique_id_in_scope); } diff --git a/src/protocol/eval/executor.rs b/src/protocol/eval/executor.rs index 78ab516f15000eec257786982eaa2250d7c3cb02..ac845a93c1640ab5270a557c106b806c13b5618b 100644 --- a/src/protocol/eval/executor.rs +++ b/src/protocol/eval/executor.rs @@ -38,35 +38,34 @@ impl Frame { /// Creates a new execution frame. Does not modify the stack in any way. pub fn new(heap: &Heap, definition_id: DefinitionId, monomorph_idx: i32) -> Self { let definition = &heap[definition_id]; - let first_statement = match definition { - Definition::Component(definition) => definition.body, - Definition::Function(definition) => definition.body, + let (outer_scope_id, first_statement_id) = match definition { + Definition::Component(definition) => (definition.scope, definition.body), + Definition::Function(definition) => (definition.scope, definition.body), _ => unreachable!("initializing frame with {:?} instead of a function/component", definition), }; // Another not-so-pretty thing that has to be replaced somewhere in the // future... 
- fn determine_max_stack_size(heap: &Heap, block_id: BlockStatementId, max_size: &mut u32) { - let block_stmt = &heap[block_id]; - debug_assert!(block_stmt.next_unique_id_in_scope >= 0); + fn determine_max_stack_size(heap: &Heap, scope_id: ScopeId, max_size: &mut u32) { + let scope = &heap[scope_id]; // Check current block - let cur_size = block_stmt.next_unique_id_in_scope as u32; + let cur_size = scope.next_unique_id_in_scope as u32; if cur_size > *max_size { *max_size = cur_size; } // And child blocks - for child_scope in &block_stmt.scope_node.nested { - determine_max_stack_size(heap, child_scope.to_block(), max_size); + for child_scope in &scope.nested { + determine_max_stack_size(heap, *child_scope, max_size); } } let mut max_stack_size = 0; - determine_max_stack_size(heap, first_statement, &mut max_stack_size); + determine_max_stack_size(heap, outer_scope_id, &mut max_stack_size); Frame{ definition: definition_id, monomorph_idx, - position: first_statement.upcast(), + position: first_statement_id.upcast(), expr_stack: VecDeque::with_capacity(128), expr_values: VecDeque::with_capacity(128), max_stack_size, @@ -826,7 +825,8 @@ impl Prompt { }, Statement::EndBlock(stmt) => { let block = &heap[stmt.start_block]; - self.store.clear_stack(block.first_unique_id_in_scope as usize); + let scope = &heap[block.scope]; + self.store.clear_stack(scope.first_unique_id_in_scope as usize); cur_frame.position = stmt.next; Ok(EvalContinuation::Stepping) @@ -873,9 +873,9 @@ impl Prompt { let test_value = cur_frame.expr_values.pop_back().unwrap(); let test_value = self.store.maybe_read_ref(&test_value).as_bool(); if test_value { - cur_frame.position = stmt.true_body.upcast(); - } else if let Some(false_body) = stmt.false_body { - cur_frame.position = false_body.upcast(); + cur_frame.position = stmt.true_case.body; + } else if let Some(false_body) = stmt.false_case { + cur_frame.position = false_body.body; } else { // Not true, and no false body cur_frame.position = 
stmt.end_if.upcast(); @@ -885,6 +885,13 @@ impl Prompt { }, Statement::EndIf(stmt) => { cur_frame.position = stmt.next; + let if_stmt = &heap[stmt.start_if]; + debug_assert_eq!( + heap[if_stmt.true_case.scope].first_unique_id_in_scope, + heap[if_stmt.false_case.unwrap_or(if_stmt.true_case).scope].first_unique_id_in_scope, + ); + let scope = &heap[if_stmt.true_case.scope]; + self.store.clear_stack(scope.first_unique_id_in_scope as usize); Ok(EvalContinuation::Stepping) }, Statement::While(stmt) => { @@ -892,7 +899,7 @@ impl Prompt { let test_value = cur_frame.expr_values.pop_back().unwrap(); let test_value = self.store.maybe_read_ref(&test_value).as_bool(); if test_value { - cur_frame.position = stmt.body.upcast(); + cur_frame.position = stmt.body; } else { cur_frame.position = stmt.end_while.upcast(); } @@ -901,7 +908,9 @@ impl Prompt { }, Statement::EndWhile(stmt) => { cur_frame.position = stmt.next; - + let start_while = &heap[stmt.start_while]; + let scope = &heap[start_while.scope]; + self.store.clear_stack(scope.first_unique_id_in_scope as usize); Ok(EvalContinuation::Stepping) }, Statement::Break(stmt) => { @@ -915,27 +924,30 @@ impl Prompt { Ok(EvalContinuation::Stepping) }, Statement::Synchronous(stmt) => { - cur_frame.position = stmt.body.upcast(); + cur_frame.position = stmt.body; Ok(EvalContinuation::SyncBlockStart) }, Statement::EndSynchronous(stmt) => { cur_frame.position = stmt.next; + let start_synchronous = &heap[stmt.start_sync]; + let scope = &heap[start_synchronous.scope]; + self.store.clear_stack(scope.first_unique_id_in_scope as usize); Ok(EvalContinuation::SyncBlockEnd) }, Statement::Fork(stmt) => { if stmt.right_body.is_none() { // No reason to fork - cur_frame.position = stmt.left_body.upcast(); + cur_frame.position = stmt.left_body; } else { // Need to fork if let Some(go_left) = ctx.performed_fork() { // Runtime has created a fork if go_left { - cur_frame.position = stmt.left_body.upcast(); + cur_frame.position = stmt.left_body; } else { - 
cur_frame.position = stmt.right_body.unwrap().upcast(); + cur_frame.position = stmt.right_body.unwrap(); } } else { // Request the runtime to create a fork of the current @@ -956,6 +968,11 @@ impl Prompt { }, Statement::EndSelect(stmt) => { cur_frame.position = stmt.next; + let start_select = &heap[stmt.start_select]; + if let Some(select_case) = start_select.cases.first() { + let scope = &heap[select_case.scope]; + self.store.clear_stack(scope.first_unique_id_in_scope as usize); + } Ok(EvalContinuation::Stepping) }, diff --git a/src/protocol/eval/value.rs b/src/protocol/eval/value.rs index b6c8b42c3ca0ec60c9cbcc702bdd854b4cb5fc35..9d9d1736c3a51487dd77f3215782dd1149e0671b 100644 --- a/src/protocol/eval/value.rs +++ b/src/protocol/eval/value.rs @@ -520,7 +520,7 @@ pub(crate) fn apply_unary_operator(store: &mut Store, op: UnaryOperator, value: Value::SInt32(v) => Value::SInt32($apply *v), Value::SInt64(v) => Value::SInt64($apply *v), _ => unreachable!("apply_unary_operator {:?} on value {:?}", $op, $value), - }; + } } } @@ -542,7 +542,7 @@ pub(crate) fn apply_unary_operator(store: &mut Store, op: UnaryOperator, value: _ => unreachable!("apply_unary_operator {:?} on value {:?}", op, value), } }, - UO::BitwiseNot => { apply_int_expr_and_return!(value, !, op)}, + UO::BitwiseNot => { apply_int_expr_and_return!(value, !, op); }, UO::LogicalNot => { return Value::Bool(!value.as_bool()); }, } } diff --git a/src/protocol/parser/mod.rs b/src/protocol/parser/mod.rs index 4688c9fe5e963e7877606b9384ae685762c2521c..3dca61f0109c63bbe526e10dfd69e8e9cbf6feeb 100644 --- a/src/protocol/parser/mod.rs +++ b/src/protocol/parser/mod.rs @@ -282,23 +282,24 @@ fn insert_builtin_function (Vec<(&'static str, Pa span: InputSpan::new(), identifier: Identifier{ span: InputSpan::new(), value: func_ident_ref.clone() }, poly_vars, - return_types: Vec::new(), + return_type: ParserType{ elements: Vec::new(), full_span: InputSpan::new() }, parameters: Vec::new(), + scope: ScopeId::new_invalid(), body: 
BlockStatementId::new_invalid(), num_expressions_in_body: -1, }); - let (args, ret) = arg_and_return_fn(func_id); + let (arguments, return_type) = arg_and_return_fn(func_id); - let mut parameters = Vec::with_capacity(args.len()); - for (arg_name, arg_type) in args { + let mut parameters = Vec::with_capacity(arguments.len()); + for (arg_name, arg_type) in arguments { let identifier = Identifier{ span: InputSpan::new(), value: p.string_pool.intern(arg_name.as_bytes()) }; let param_id = p.heap.alloc_variable(|this| Variable{ this, kind: VariableKind::Parameter, parser_type: arg_type.clone(), identifier, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, unique_id_in_scope: 0 }); parameters.push(param_id); @@ -306,7 +307,7 @@ fn insert_builtin_function (Vec<(&'static str, Pa let func = &mut p.heap[func_id]; func.parameters = parameters; - func.return_types.push(ret); + func.return_type = return_type; p.symbol_table.insert_symbol(SymbolScope::Global, Symbol{ name: func_ident_ref, diff --git a/src/protocol/parser/pass_definitions.rs b/src/protocol/parser/pass_definitions.rs index 66ed638cc16da38f731d12225efd5d7c0a7da217..3b0d9abe5004eb827700c49d090cd32e560feba1 100644 --- a/src/protocol/parser/pass_definitions.rs +++ b/src/protocol/parser/pass_definitions.rs @@ -273,33 +273,18 @@ impl PassDefinitions { // Consume return types consume_token(&module.source, iter, TokenKind::ArrowRight)?; - let mut return_types = self.parser_types.start_section(); - let mut open_curly_pos = iter.last_valid_pos(); // bogus value - consume_comma_separated_until( - TokenKind::OpenCurly, &module.source, iter, ctx, - |source, iter, ctx| { - let poly_vars = ctx.heap[definition_id].poly_vars(); - self.type_parser.consume_parser_type( - iter, &ctx.heap, source, &ctx.symbols, poly_vars, definition_id, - module_scope, false, None - ) - }, - &mut return_types, "a return type", Some(&mut open_curly_pos) + let poly_vars = ctx.heap[definition_id].poly_vars(); + let parser_type = 
self.type_parser.consume_parser_type( + iter, &ctx.heap, &module.source, &ctx.symbols, poly_vars, definition_id, + module_scope, false, None )?; - let return_types = return_types.into_vec(); - - match return_types.len() { - 0 => return Err(ParseError::new_error_str_at_pos(&module.source, open_curly_pos, "expected a return type")), - 1 => {}, - _ => return Err(ParseError::new_error_str_at_pos(&module.source, open_curly_pos, "multiple return types are not (yet) allowed")), - } - // Consume block - let body = self.consume_block_statement_without_leading_curly(module, iter, ctx, open_curly_pos)?; + // Consume block and the definition's scope + let body = self.consume_block_statement(module, iter, ctx)?; // Assign everything in the preallocated AST node let function = ctx.heap[definition_id].as_function_mut(); - function.return_types = return_types; + function.return_type = parser_type; function.parameters = parameters; function.body = body; @@ -340,187 +325,136 @@ impl PassDefinitions { Ok(()) } - /// Consumes a block statement. If the resulting statement is not a block - /// (e.g. 
for a shorthand "if (expr) single_statement") then it will be - /// wrapped in one - fn consume_block_or_wrapped_statement( - &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx - ) -> Result { - if Some(TokenKind::OpenCurly) == iter.next() { - // This is a block statement - self.consume_block_statement(module, iter, ctx) - } else { - // Not a block statement, so wrap it in one - let mut statements = self.statements.start_section(); - let wrap_begin_pos = iter.last_valid_pos(); - self.consume_statement(module, iter, ctx, &mut statements)?; - let wrap_end_pos = iter.last_valid_pos(); - - let statements = statements.into_vec(); - - let id = ctx.heap.alloc_block_statement(|this| BlockStatement{ - this, - is_implicit: true, - span: InputSpan::from_positions(wrap_begin_pos, wrap_end_pos), - statements, - end_block: EndBlockStatementId::new_invalid(), - scope_node: ScopeNode::new_invalid(), - first_unique_id_in_scope: -1, - next_unique_id_in_scope: -1, - locals: Vec::new(), - labels: Vec::new(), - next: StatementId::new_invalid(), - }); - - let end_block = ctx.heap.alloc_end_block_statement(|this| EndBlockStatement{ - this, start_block: id, next: StatementId::new_invalid() - }); - - let block_stmt = &mut ctx.heap[id]; - block_stmt.end_block = end_block; - - Ok(id) - } - } - /// Consumes a statement and returns a boolean indicating whether it was a /// block or not. 
- fn consume_statement( - &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx, section: &mut ScopedSection - ) -> Result<(), ParseError> { + fn consume_statement(&mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx) -> Result { let next = iter.next().expect("consume_statement has a next token"); if next == TokenKind::OpenCurly { let id = self.consume_block_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } else if next == TokenKind::Ident { let ident = peek_ident(&module.source, iter).unwrap(); if ident == KW_STMT_IF { // Consume if statement and place end-if statement directly // after it. let id = self.consume_if_statement(module, iter, ctx)?; - section.push(id.upcast()); let end_if = ctx.heap.alloc_end_if_statement(|this| EndIfStatement { this, start_if: id, next: StatementId::new_invalid() }); - section.push(end_if.upcast()); let if_stmt = &mut ctx.heap[id]; if_stmt.end_if = end_if; + + return Ok(id.upcast()); } else if ident == KW_STMT_WHILE { let id = self.consume_while_statement(module, iter, ctx)?; - section.push(id.upcast()); let end_while = ctx.heap.alloc_end_while_statement(|this| EndWhileStatement { this, start_while: id, next: StatementId::new_invalid() }); - section.push(end_while.upcast()); let while_stmt = &mut ctx.heap[id]; while_stmt.end_while = end_while; + + return Ok(id.upcast()); } else if ident == KW_STMT_BREAK { let id = self.consume_break_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } else if ident == KW_STMT_CONTINUE { let id = self.consume_continue_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } else if ident == KW_STMT_SYNC { let id = self.consume_synchronous_statement(module, iter, ctx)?; - section.push(id.upcast()); let end_sync = ctx.heap.alloc_end_synchronous_statement(|this| EndSynchronousStatement { this, start_sync: id, next: StatementId::new_invalid() }); - 
section.push(end_sync.upcast()); let sync_stmt = &mut ctx.heap[id]; sync_stmt.end_sync = end_sync; + + return Ok(id.upcast()); } else if ident == KW_STMT_FORK { let id = self.consume_fork_statement(module, iter, ctx)?; - section.push(id.upcast()); let end_fork = ctx.heap.alloc_end_fork_statement(|this| EndForkStatement { this, start_fork: id, next: StatementId::new_invalid(), }); - section.push(end_fork.upcast()); let fork_stmt = &mut ctx.heap[id]; fork_stmt.end_fork = end_fork; + + return Ok(id.upcast()); } else if ident == KW_STMT_SELECT { let id = self.consume_select_statement(module, iter, ctx)?; - section.push(id.upcast()); let end_select = ctx.heap.alloc_end_select_statement(|this| EndSelectStatement{ this, start_select: id, next: StatementId::new_invalid(), }); - section.push(end_select.upcast()); let select_stmt = &mut ctx.heap[id]; select_stmt.end_select = end_select; + + return Ok(id.upcast()); } else if ident == KW_STMT_RETURN { let id = self.consume_return_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } else if ident == KW_STMT_GOTO { let id = self.consume_goto_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } else if ident == KW_STMT_NEW { let id = self.consume_new_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } else if ident == KW_STMT_CHANNEL { let id = self.consume_channel_statement(module, iter, ctx)?; - section.push(id.upcast().upcast()); + return Ok(id.upcast().upcast()); } else if iter.peek() == Some(TokenKind::Colon) { - self.consume_labeled_statement(module, iter, ctx, section)?; + let id = self.consume_labeled_statement(module, iter, ctx)?; + return Ok(id.upcast()); } else { // Two fallback possibilities: the first one is a memory // declaration, the other one is to parse it as a normal // expression. This is a bit ugly. if let Some(memory_stmt_id) = self.maybe_consume_memory_statement_without_semicolon(module, iter, ctx)? 
{ consume_token(&module.source, iter, TokenKind::SemiColon)?; - section.push(memory_stmt_id.upcast().upcast()); + return Ok(memory_stmt_id.upcast().upcast()); } else { let id = self.consume_expression_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } } } else if next == TokenKind::OpenParen { // Same as above: memory statement or normal expression if let Some(memory_stmt_id) = self.maybe_consume_memory_statement_without_semicolon(module, iter, ctx)? { consume_token(&module.source, iter, TokenKind::SemiColon)?; - section.push(memory_stmt_id.upcast().upcast()); + return Ok(memory_stmt_id.upcast().upcast()); } else { let id = self.consume_expression_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } } else { let id = self.consume_expression_statement(module, iter, ctx)?; - section.push(id.upcast()); + return Ok(id.upcast()); } - - return Ok(()); } fn consume_block_statement( &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx ) -> Result { - let open_span = consume_token(&module.source, iter, TokenKind::OpenCurly)?; - self.consume_block_statement_without_leading_curly(module, iter, ctx, open_span.begin) - } + let open_curly_span = consume_token(&module.source, iter, TokenKind::OpenCurly)?; - fn consume_block_statement_without_leading_curly( - &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx, open_curly_pos: InputPosition - ) -> Result { let mut stmt_section = self.statements.start_section(); let mut next = iter.next(); while next != Some(TokenKind::CloseCurly) { @@ -529,13 +463,14 @@ impl PassDefinitions { &module.source, iter.last_valid_pos(), "expected a statement or '}'" )); } - self.consume_statement(module, iter, ctx, &mut stmt_section)?; + let stmt_id = self.consume_statement(module, iter, ctx)?; + stmt_section.push(stmt_id); next = iter.next(); } let statements = stmt_section.into_vec(); let mut block_span = consume_token(&module.source, iter, 
TokenKind::CloseCurly)?; - block_span.begin = open_curly_pos; + block_span.begin = open_curly_span.begin; let id = ctx.heap.alloc_block_statement(|this| BlockStatement{ this, @@ -543,11 +478,7 @@ impl PassDefinitions { span: block_span, statements, end_block: EndBlockStatementId::new_invalid(), - scope_node: ScopeNode::new_invalid(), - first_unique_id_in_scope: -1, - next_unique_id_in_scope: -1, - locals: Vec::new(), - labels: Vec::new(), + scope: ScopeId::new_invalid(), next: StatementId::new_invalid(), }); @@ -568,24 +499,36 @@ impl PassDefinitions { consume_token(&module.source, iter, TokenKind::OpenParen)?; let test = self.consume_expression(module, iter, ctx)?; consume_token(&module.source, iter, TokenKind::CloseParen)?; - let true_body = self.consume_block_or_wrapped_statement(module, iter, ctx)?; - let false_body = if has_ident(&module.source, iter, KW_STMT_ELSE) { + let true_body = IfStatementCase{ + body: self.consume_statement(module, iter, ctx)?, + scope: ScopeId::new_invalid(), + }; + let true_body_scope_id = true_body.scope; + + let (false_body, false_body_scope_id) = if has_ident(&module.source, iter, KW_STMT_ELSE) { iter.consume(); - let false_body = self.consume_block_or_wrapped_statement(module, iter, ctx)?; - Some(false_body) + let false_body = IfStatementCase{ + body: self.consume_statement(module, iter, ctx)?, + scope: ScopeId::new_invalid(), + }; + + let false_body_scope_id = false_body.scope; + (Some(false_body), Some(false_body_scope_id)) } else { - None + (None, None) }; - Ok(ctx.heap.alloc_if_statement(|this| IfStatement{ + let if_stmt_id = ctx.heap.alloc_if_statement(|this| IfStatement{ this, span: if_span, test, - true_body, - false_body, + true_case: true_body, + false_case: false_body, end_if: EndIfStatementId::new_invalid(), - })) + }); + + return Ok(if_stmt_id); } fn consume_while_statement( @@ -595,12 +538,13 @@ impl PassDefinitions { consume_token(&module.source, iter, TokenKind::OpenParen)?; let test = 
self.consume_expression(module, iter, ctx)?; consume_token(&module.source, iter, TokenKind::CloseParen)?; - let body = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + let body = self.consume_statement(module, iter, ctx)?; Ok(ctx.heap.alloc_while_statement(|this| WhileStatement{ this, span: while_span, test, + scope: ScopeId::new_invalid(), body, end_while: EndWhileStatementId::new_invalid(), in_sync: SynchronousStatementId::new_invalid(), @@ -649,11 +593,12 @@ impl PassDefinitions { &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx ) -> Result { let synchronous_span = consume_exact_ident(&module.source, iter, KW_STMT_SYNC)?; - let body = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + let body = self.consume_statement(module, iter, ctx)?; Ok(ctx.heap.alloc_synchronous_statement(|this| SynchronousStatement{ this, span: synchronous_span, + scope: ScopeId::new_invalid(), body, end_sync: EndSynchronousStatementId::new_invalid(), })) @@ -663,11 +608,11 @@ impl PassDefinitions { &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx ) -> Result { let fork_span = consume_exact_ident(&module.source, iter, KW_STMT_FORK)?; - let left_body = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + let left_body = self.consume_statement(module, iter, ctx)?; let right_body = if has_ident(&module.source, iter, KW_STMT_OR) { iter.consume(); - let right_body = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + let right_body = self.consume_statement(module, iter, ctx)?; Some(right_body) } else { None @@ -710,9 +655,11 @@ impl PassDefinitions { }, }; consume_token(&module.source, iter, TokenKind::ArrowRight)?; - let block = self.consume_block_or_wrapped_statement(module, iter, ctx)?; + let block = self.consume_statement(module, iter, ctx)?; cases.push(SelectCase{ - guard, block, + guard, + body: block, + scope: ScopeId::new_invalid(), involved_ports: Vec::with_capacity(1) }); @@ -855,7 +802,7 @@ impl 
PassDefinitions { kind: VariableKind::Local, identifier: from_identifier, parser_type: from_port_type, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, unique_id_in_scope: -1, }); @@ -870,7 +817,7 @@ impl PassDefinitions { kind: VariableKind::Local, identifier: to_identifier, parser_type: to_port_type, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, unique_id_in_scope: -1, }); @@ -879,49 +826,25 @@ impl PassDefinitions { this, span: channel_span, from, to, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, next: StatementId::new_invalid(), })) } - fn consume_labeled_statement( - &mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx, section: &mut ScopedSection - ) -> Result<(), ParseError> { + fn consume_labeled_statement(&mut self, module: &Module, iter: &mut TokenIter, ctx: &mut PassCtx) -> Result { let label = consume_ident_interned(&module.source, iter, ctx)?; consume_token(&module.source, iter, TokenKind::Colon)?; - // Not pretty: consume_statement may produce more than one statement. - // The values in the section need to be in the correct order if some - // kind of outer block is consumed, so we take another section, push - // the expressions in that one, and then allocate the labeled statement. - let mut inner_section = self.statements.start_section(); - self.consume_statement(module, iter, ctx, &mut inner_section)?; - debug_assert!(inner_section.len() >= 1); - + let inner_stmt_id = self.consume_statement(module, iter, ctx)?; let stmt_id = ctx.heap.alloc_labeled_statement(|this| LabeledStatement { this, label, - body: inner_section[0], - relative_pos_in_block: 0, + body: inner_stmt_id, + relative_pos_in_parent: 0, in_sync: SynchronousStatementId::new_invalid(), }); - if inner_section.len() == 1 { - // Produce the labeled statement pointing to the first statement. - // This is by far the most common case. 
- inner_section.forget(); - section.push(stmt_id.upcast()); - } else { - // Produce the labeled statement using the first statement, and push - // the remaining ones at the end. - let inner_statements = inner_section.into_vec(); - section.push(stmt_id.upcast()); - for idx in 1..inner_statements.len() { - section.push(inner_statements[idx]) - } - } - - Ok(()) + return Ok(stmt_id); } /// Attempts to consume a memory statement (a statement along the lines of @@ -960,7 +883,7 @@ impl PassDefinitions { kind: VariableKind::Local, identifier: identifier.clone(), parser_type, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, unique_id_in_scope: -1, }); @@ -1883,7 +1806,7 @@ fn consume_parameter_list( kind: VariableKind::Parameter, parser_type, identifier, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, unique_id_in_scope: -1, }); Ok(parameter_id) diff --git a/src/protocol/parser/pass_rewriting.rs b/src/protocol/parser/pass_rewriting.rs index 517de373802e3d79c4936c3afa27b312898e606b..80554d91b426312cf0301354c1b037d4f10e5f09 100644 --- a/src/protocol/parser/pass_rewriting.rs +++ b/src/protocol/parser/pass_rewriting.rs @@ -59,12 +59,12 @@ impl Visitor for PassRewriting { fn visit_if_stmt(&mut self, ctx: &mut Ctx, id: IfStatementId) -> VisitorResult { let if_stmt = &ctx.heap[id]; - let true_body_id = if_stmt.true_body; - let false_body_id = if_stmt.false_body; + let true_body_id = if_stmt.true_case; + let false_body_id = if_stmt.false_case; - self.visit_block_stmt(ctx, true_body_id)?; + self.visit_stmt(ctx, true_body_id.body)?; if let Some(false_body_id) = false_body_id { - self.visit_block_stmt(ctx, false_body_id)?; + self.visit_stmt(ctx, false_body_id.body)?; } return Ok(()) @@ -73,37 +73,38 @@ impl Visitor for PassRewriting { fn visit_while_stmt(&mut self, ctx: &mut Ctx, id: WhileStatementId) -> VisitorResult { let while_stmt = &ctx.heap[id]; let body_id = while_stmt.body; - return self.visit_block_stmt(ctx, body_id); + return self.visit_stmt(ctx, body_id); 
} fn visit_synchronous_stmt(&mut self, ctx: &mut Ctx, id: SynchronousStatementId) -> VisitorResult { let sync_stmt = &ctx.heap[id]; let body_id = sync_stmt.body; - return self.visit_block_stmt(ctx, body_id); + return self.visit_stmt(ctx, body_id); } // --- Visiting the select statement fn visit_select_stmt(&mut self, ctx: &mut Ctx, id: SelectStatementId) -> VisitorResult { - // Utility for the last stage of rewriting process + // Utility for the last stage of rewriting process. Note that caller + // still needs to point the end of the if-statement to the end of the + // replacement statement of the select statement. fn transform_select_case_code(ctx: &mut Ctx, select_id: SelectStatementId, case_index: usize, select_var_id: VariableId) -> (IfStatementId, EndIfStatementId) { // Retrieve statement IDs associated with case let case = &ctx.heap[select_id].cases[case_index]; let case_guard_id = case.guard; - let case_body_id = case.block; + let case_body_id = case.body; + let case_scope_id = case.scope; // Create the if-statement for the result of the select statement let compare_expr_id = create_ast_equality_comparison_expr(ctx, select_var_id, case_index as u64); - let (if_stmt_id, end_if_stmt_id) = create_ast_if_stmt(ctx, compare_expr_id.upcast(), case_body_id, None); - - // Modify body of case to link up to the surrounding statements - // correctly - let case_body = &mut ctx.heap[case_body_id]; - let case_end_body_id = case_body.end_block; - case_body.statements.insert(0, case_guard_id); + let true_case = IfStatementCase{ + body: case_guard_id, // which is linked up to the body + scope: case_scope_id, + }; + let (if_stmt_id, end_if_stmt_id) = create_ast_if_stmt(ctx, compare_expr_id.upcast(), true_case, None); - let case_end_body = &mut ctx.heap[case_end_body_id]; - case_end_body.next = end_if_stmt_id.upcast(); + // Link up body statement to end-if + set_ast_statement_next(ctx, case_body_id, end_if_stmt_id.upcast()); return (if_stmt_id, end_if_stmt_id) } @@ -112,7 
+113,7 @@ impl Visitor for PassRewriting { // containing builtin runtime-calls. And to do so we create temporary // variables and move some other statements around. let select_stmt = &ctx.heap[id]; - let mut total_num_cases = select_stmt.cases.len(); + let total_num_cases = select_stmt.cases.len(); let mut total_num_ports = 0; let end_select_stmt_id = select_stmt.end_select; let end_select = &ctx.heap[end_select_stmt_id]; @@ -165,7 +166,7 @@ impl Visitor for PassRewriting { num_ports_expression_id.upcast() ]; - let call_expression_id = create_ast_call_expr(ctx, Method::SelectStart, arguments); + let call_expression_id = create_ast_call_expr(ctx, Method::SelectStart, &mut self.expression_buffer, arguments); let call_statement_id = create_ast_expression_stmt(ctx, call_expression_id.upcast()); transformed_stmts.push(call_statement_id.upcast()); @@ -192,7 +193,7 @@ impl Visitor for PassRewriting { ]; // Create runtime call, then store it - let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectRegisterCasePort, runtime_call_arguments); + let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectRegisterCasePort, &mut self.expression_buffer, runtime_call_arguments); let runtime_call_stmt_id = create_ast_expression_stmt(ctx, runtime_call_expr_id.upcast()); transformed_stmts.push(runtime_call_stmt_id.upcast()); @@ -208,7 +209,7 @@ impl Visitor for PassRewriting { locals.push(select_variable_id); { - let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectWait, Vec::new()); + let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectWait, &mut self.expression_buffer, Vec::new()); let variable_stmt_id = create_ast_variable_declaration_stmt(ctx, select_variable_id, runtime_call_expr_id.upcast()); transformed_stmts.push(variable_stmt_id.upcast().upcast()); } @@ -224,11 +225,7 @@ impl Visitor for PassRewriting { span: InputSpan::new(), statements: Vec::new(), end_block: EndBlockStatementId::new_invalid(), - scope_node: 
ScopeNode::new_invalid(), - first_unique_id_in_scope: -1, - next_unique_id_in_scope: -1, - locals: Vec::new(), - labels: Vec::new(), + scope: ScopeId::new_invalid(), next: StatementId::new_invalid(), }); let end_block_id = ctx.heap.alloc_end_block_statement(|this| EndBlockStatement{ @@ -248,7 +245,7 @@ impl Visitor for PassRewriting { for case_index in 1..total_num_cases { let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, id, case_index, select_variable_id); let last_if_stmt = &mut ctx.heap[last_if_stmt_id]; - last_if_stmt.false_body = Some(if_stmt_id.upcast()); + // last_if_stmt.false_case = Some(if_stmt_id.upcast()); // TODO: // 1. Change scoping such that it is a separate datastructure with separate IDs @@ -365,7 +362,7 @@ fn create_ast_variable(ctx: &mut Ctx) -> VariableId { full_span: InputSpan::new(), }, identifier: Identifier::new_empty(InputSpan::new()), - relative_pos_in_block: -1, + relative_pos_in_parent: -1, unique_id_in_scope: -1, }); } @@ -381,7 +378,8 @@ fn create_ast_variable_expr(ctx: &mut Ctx, variable_id: VariableId) -> VariableE }); } -fn create_ast_call_expr(ctx: &mut Ctx, method: Method, arguments: Vec) -> CallExpressionId { +fn create_ast_call_expr(ctx: &mut Ctx, method: Method, buffer: &mut ScopedBuffer, arguments: Vec) -> CallExpressionId { + let expression_ids = buffer.start_section_initialized(&arguments); let call_expression_id = ctx.heap.alloc_call_expression(|this| CallExpression{ this, func_span: InputSpan::new(), @@ -397,9 +395,10 @@ fn create_ast_call_expr(ctx: &mut Ctx, method: Method, arguments: Vec) -> (IfStatementId, EndIfStatementId) { +fn create_ast_if_stmt(ctx: &mut Ctx, condition_expression_id: ExpressionId, true_case: IfStatementCase, false_case: Option) -> (IfStatementId, EndIfStatementId) { // Create if statement and the end-if statement let if_stmt_id = ctx.heap.alloc_if_statement(|this| IfStatement{ this, span: InputSpan::new(), test: condition_expression_id, - true_body, - false_body, + true_case, + 
false_case, end_if: EndIfStatementId::new_invalid() }); @@ -512,27 +511,76 @@ fn create_ast_if_stmt(ctx: &mut Ctx, condition_expression_id: ExpressionId, true let condition_expr = &mut ctx.heap[condition_expression_id]; *condition_expr.parent_mut() = ExpressionParent::If(if_stmt_id); - let true_body_stmt = &ctx.heap[true_body]; - let true_body_end_stmt = &mut ctx.heap[true_body_stmt.end_block]; - true_body_end_stmt.next = end_if_stmt_id.upcast(); - - if let Some(false_body) = false_body { - let false_body_stmt = &ctx.heap[false_body]; - let false_body_end_stmt = &mut ctx.heap[false_body_stmt.end_block]; - false_body_end_stmt.next = end_if_stmt_id.upcast(); - } - return (if_stmt_id, end_if_stmt_id); } -fn set_ast_if_statement_false_body(ctx: &mut Ctx, if_statement_id: IfStatementId, end_if_statement_id: EndIfStatementId, false_body: BlockStatementId) { +fn set_ast_if_statement_false_body(ctx: &mut Ctx, if_statement_id: IfStatementId, end_if_statement_id: EndIfStatementId, false_body_id: StatementId) { // Point if-statement to "false body" + todo!("set scopes"); let if_stmt = &mut ctx.heap[if_statement_id]; - debug_assert!(if_stmt.false_body.is_none()); // simplifies logic, not necessary - if_stmt.false_body = Some(false_body); + debug_assert!(if_stmt.false_case.is_none()); // simplifies logic, not necessary + if_stmt.false_case = Some(IfStatementCase{ + body: false_body_id, + scope: ScopeId::new_invalid(), + }); // Point end of false body to the end of the if statement - let false_body_stmt = &ctx.heap[false_body]; - let false_body_end_stmt = &mut ctx.heap[false_body_stmt.end_block]; - false_body_end_stmt.next = end_if_statement_id.upcast(); + set_ast_statement_next(ctx, false_body_id, end_if_statement_id.upcast()); +} + +/// Sets the specified AST statement's control flow such that it will be +/// followed by the target statement. This may seem obvious, but may imply that +/// a statement associated with, but different from, the source statement is +/// modified. 
+fn set_ast_statement_next(ctx: &mut Ctx, source_stmt_id: StatementId, target_stmt_id: StatementId) { + let source_stmt = &mut ctx.heap[source_stmt_id]; + match source_stmt { + Statement::Block(stmt) => { + let end_id = stmt.end_block; + ctx.heap[end_id].next = target_stmt_id + }, + Statement::EndBlock(stmt) => stmt.next = target_stmt_id, + Statement::Local(stmt) => { + match stmt { + LocalStatement::Memory(stmt) => stmt.next = target_stmt_id, + LocalStatement::Channel(stmt) => stmt.next = target_stmt_id, + } + }, + Statement::Labeled(stmt) => { + let body_id = stmt.body; + set_ast_statement_next(ctx, body_id, target_stmt_id); + }, + Statement::If(stmt) => { + let end_id = stmt.end_if; + ctx.heap[end_id].next = target_stmt_id; + }, + Statement::EndIf(stmt) => stmt.next = target_stmt_id, + Statement::While(stmt) => { + let end_id = stmt.end_while; + ctx.heap[end_id].next = target_stmt_id; + }, + Statement::EndWhile(stmt) => stmt.next = target_stmt_id, + + Statement::Break(_stmt) => {}, + Statement::Continue(_stmt) => {}, + Statement::Synchronous(stmt) => { + let end_id = stmt.end_sync; + ctx.heap[end_id].next = target_stmt_id; + }, + Statement::EndSynchronous(stmt) => { + stmt.next = target_stmt_id; + }, + Statement::Fork(_) | Statement::EndFork(_) => { + todo!("remove fork from language"); + }, + Statement::Select(stmt) => { + let end_id = stmt.end_select; + ctx.heap[end_id].next = target_stmt_id; + }, + Statement::EndSelect(stmt) => stmt.next = target_stmt_id, + Statement::Return(_stmt) => {}, + Statement::Goto(_stmt) => {}, + Statement::New(stmt) => stmt.next = target_stmt_id, + Statement::Expression(stmt) => stmt.next = target_stmt_id, + } } \ No newline at end of file diff --git a/src/protocol/parser/pass_typing.rs b/src/protocol/parser/pass_typing.rs index e31e13c896efe7558d80c95de258c845868faeb9..dace9be4d17124fe5a046dd1368aafef815f140f 100644 --- a/src/protocol/parser/pass_typing.rs +++ b/src/protocol/parser/pass_typing.rs @@ -56,7 +56,8 @@ use 
crate::protocol::parser::ModuleCompilationPhase; use crate::protocol::parser::type_table::*; use crate::protocol::parser::token_parsing::*; use super::visitor::{ - BUFFER_INIT_CAPACITY, + BUFFER_INIT_CAP_LARGE, + BUFFER_INIT_CAP_SMALL, Ctx, Visitor, VisitorResult @@ -921,10 +922,10 @@ impl PassTyping { reserved_idx: -1, definition_type: DefinitionType::Function(FunctionDefinitionId::new_invalid()), poly_vars: Vec::new(), - var_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAPACITY), - expr_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAPACITY), - stmt_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAPACITY), - bool_buffer: ScopedBuffer::with_capacity(16), + var_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + expr_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + stmt_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + bool_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), var_types: HashMap::new(), expr_types: Vec::new(), extra_data: Vec::new(), @@ -1130,14 +1131,14 @@ impl Visitor for PassTyping { fn visit_if_stmt(&mut self, ctx: &mut Ctx, id: IfStatementId) -> VisitorResult { let if_stmt = &ctx.heap[id]; - let true_body_id = if_stmt.true_body; - let false_body_id = if_stmt.false_body; + let true_body_case = if_stmt.true_case; + let false_body_case = if_stmt.false_case; let test_expr_id = if_stmt.test; self.visit_expr(ctx, test_expr_id)?; - self.visit_block_stmt(ctx, true_body_id)?; - if let Some(false_body_id) = false_body_id { - self.visit_block_stmt(ctx, false_body_id)?; + self.visit_stmt(ctx, true_body_case.body)?; + if let Some(false_body_case) = false_body_case { + self.visit_stmt(ctx, false_body_case.body)?; } Ok(()) @@ -1150,7 +1151,7 @@ impl Visitor for PassTyping { let test_expr_id = while_stmt.test; self.visit_expr(ctx, test_expr_id)?; - self.visit_block_stmt(ctx, body_id)?; + self.visit_stmt(ctx, body_id)?; Ok(()) } @@ -1159,7 +1160,7 @@ impl Visitor for PassTyping { let sync_stmt = &ctx.heap[id]; let 
body_id = sync_stmt.body; - self.visit_block_stmt(ctx, body_id) + self.visit_stmt(ctx, body_id) } fn visit_fork_stmt(&mut self, ctx: &mut Ctx, id: ForkStatementId) -> VisitorResult { @@ -1167,9 +1168,9 @@ impl Visitor for PassTyping { let left_body_id = fork_stmt.left_body; let right_body_id = fork_stmt.right_body; - self.visit_block_stmt(ctx, left_body_id)?; + self.visit_stmt(ctx, left_body_id)?; if let Some(right_body_id) = right_body_id { - self.visit_block_stmt(ctx, right_body_id)?; + self.visit_stmt(ctx, right_body_id)?; } Ok(()) @@ -1183,7 +1184,7 @@ impl Visitor for PassTyping { for case in &select_stmt.cases { section.push(case.guard); - section.push(case.block.upcast()); + section.push(case.body); } for case_index in 0..num_cases { @@ -3329,8 +3330,7 @@ impl PassTyping { EP::Return(_) => // Must match the return type of the function if let DefinitionType::Function(func_id) = self.definition_type { - debug_assert_eq!(ctx.heap[func_id].return_types.len(), 1); - let returned = &ctx.heap[func_id].return_types[0]; + let returned = &ctx.heap[func_id].return_type; self.determine_inference_type_from_parser_type_elements(&returned.elements, true) } else { // Cannot happen: definition always set upon body traversal @@ -3413,7 +3413,7 @@ impl PassTyping { }, Definition::Function(definition) => { debug_assert_eq!(poly_args.len(), definition.poly_vars.len()); - (&definition.parameters, Some(&definition.return_types)) + (&definition.parameters, Some(&definition.return_type)) }, Definition::Struct(_) | Definition::Enum(_) | Definition::Union(_) => { unreachable!("insert_initial_call_polymorph data for non-procedure type"); @@ -3432,8 +3432,6 @@ impl PassTyping { InferenceType::new(false, true, vec![InferenceTypePart::Void]) }, Some(returned) => { - debug_assert_eq!(returned.len(), 1); // TODO: @ReturnTypes - let returned = &returned[0]; self.determine_inference_type_from_parser_type_elements(&returned.elements, false) } }; diff --git 
a/src/protocol/parser/pass_validation_linking.rs b/src/protocol/parser/pass_validation_linking.rs index 2d3292c0b00cbd158c35f10c0cbcd4211c2beea9..954bcc968214684eadc3f67b972d00aa463b067a 100644 --- a/src/protocol/parser/pass_validation_linking.rs +++ b/src/protocol/parser/pass_validation_linking.rs @@ -42,7 +42,8 @@ use crate::protocol::parser::symbol_table::*; use crate::protocol::parser::type_table::*; use super::visitor::{ - BUFFER_INIT_CAPACITY, + BUFFER_INIT_CAP_SMALL, + BUFFER_INIT_CAP_LARGE, Ctx, Visitor, VisitorResult @@ -72,7 +73,7 @@ impl DefinitionType { struct ControlFlowStatement { in_sync: SynchronousStatementId, in_while: WhileStatementId, - in_scope: Scope, + in_scope: ScopeId, statement: StatementId, // of 'break', 'continue' or 'goto' } @@ -101,7 +102,7 @@ pub(crate) struct PassValidationLinking { in_binding_expr_lhs: bool, // Traversal state, current scope (which can be used to find the parent // scope) and the definition variant we are considering. - cur_scope: Scope, + cur_scope: ScopeId, def_type: DefinitionType, // "Trailing" traversal state, set be child/prev stmt/expr used by next one prev_stmt: StatementId, @@ -111,7 +112,7 @@ pub(crate) struct PassValidationLinking { // used for the error's position must_be_assignable: Option, // Keeping track of relative positions and unique IDs. 
- relative_pos_in_block: i32, // of statements: to determine when variables are visible + relative_pos_in_parent: i32, // of statements: to determine when variables are visible next_expr_index: i32, // to arrive at a unique ID for all expressions within a definition // Control flow statements that require label resolving control_flow_stmts: Vec, @@ -122,6 +123,7 @@ pub(crate) struct PassValidationLinking { definition_buffer: ScopedBuffer, statement_buffer: ScopedBuffer, expression_buffer: ScopedBuffer, + scope_buffer: ScopedBuffer, } impl PassValidationLinking { @@ -134,18 +136,19 @@ impl PassValidationLinking { in_test_expr: StatementId::new_invalid(), in_binding_expr: BindingExpressionId::new_invalid(), in_binding_expr_lhs: false, - cur_scope: Scope::new_invalid(), + cur_scope: ScopeId::new_invalid(), prev_stmt: StatementId::new_invalid(), expr_parent: ExpressionParent::None, def_type: DefinitionType::Function(FunctionDefinitionId::new_invalid()), must_be_assignable: None, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, next_expr_index: 0, - control_flow_stmts: Vec::with_capacity(32), - variable_buffer: ScopedBuffer::with_capacity(128), - definition_buffer: ScopedBuffer::with_capacity(128), - statement_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAPACITY), - expression_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAPACITY), + control_flow_stmts: Vec::with_capacity(BUFFER_INIT_CAP_SMALL), + variable_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + definition_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), + statement_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + expression_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), + scope_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), } } @@ -156,12 +159,12 @@ impl PassValidationLinking { self.in_test_expr = StatementId::new_invalid(); self.in_binding_expr = BindingExpressionId::new_invalid(); self.in_binding_expr_lhs = false; - self.cur_scope 
= Scope::new_invalid(); + self.cur_scope = ScopeId::new_invalid(); self.def_type = DefinitionType::Function(FunctionDefinitionId::new_invalid()); self.prev_stmt = StatementId::new_invalid(); self.expr_parent = ExpressionParent::None; self.must_be_assignable = None; - self.relative_pos_in_block = 0; + self.relative_pos_in_parent = 0; self.next_expr_index = 0; self.control_flow_stmts.clear(); } @@ -210,27 +213,33 @@ impl Visitor for PassValidationLinking { ComponentVariant::Primitive => DefinitionType::Primitive(id), ComponentVariant::Composite => DefinitionType::Composite(id), }; - self.cur_scope = Scope::Definition(id.upcast()); self.expr_parent = ExpressionParent::None; // Visit parameters and assign a unique scope ID + let old_scope = self.push_scope(ctx, ScopeAssociation::Definition(id.upcast())); + let definition = &ctx.heap[id]; let body_id = definition.body; let section = self.variable_buffer.start_section_initialized(&definition.parameters); for variable_idx in 0..section.len() { let variable_id = section[variable_idx]; - let variable = &mut ctx.heap[variable_id]; - variable.unique_id_in_scope = variable_idx as i32; + self.checked_at_single_scope_add_local(ctx, self.cur_scope, variable_idx as i32, variable_id)?; } + self.relative_pos_in_parent = section.len() as i32; + section.forget(); // Visit statements in component body self.visit_block_stmt(ctx, body_id)?; + self.pop_scope(old_scope); // Assign total number of expressions and assign an in-block unique ID // to each of the locals in the procedure. 
- ctx.heap[id].num_expressions_in_body = self.next_expr_index; - self.visit_definition_and_assign_local_ids(ctx, id.upcast()); + let definition = &mut ctx.heap[id]; + let definition_scope = definition.scope; + definition.num_expressions_in_body = self.next_expr_index; + + self.visit_scope_and_assign_local_ids(ctx, definition_scope, 0); self.resolve_pending_control_flow_targets(ctx)?; Ok(()) @@ -241,27 +250,30 @@ impl Visitor for PassValidationLinking { // Set internal statement indices self.def_type = DefinitionType::Function(id); - self.cur_scope = Scope::Definition(id.upcast()); self.expr_parent = ExpressionParent::None; // Visit parameters and assign a unique scope ID + let old_scope = self.push_scope(ctx, ScopeAssociation::Definition(id.upcast())); + let definition = &ctx.heap[id]; let body_id = definition.body; let section = self.variable_buffer.start_section_initialized(&definition.parameters); for variable_idx in 0..section.len() { let variable_id = section[variable_idx]; - let variable = &mut ctx.heap[variable_id]; - variable.unique_id_in_scope = variable_idx as i32; + self.checked_at_single_scope_add_local(ctx, self.cur_scope, variable_idx as i32, variable_id)?; } section.forget(); // Visit statements in function body self.visit_block_stmt(ctx, body_id)?; + self.pop_scope(old_scope); // Assign total number of expressions and assign an in-block unique ID // to each of the locals in the procedure. 
- ctx.heap[id].num_expressions_in_body = self.next_expr_index; - self.visit_definition_and_assign_local_ids(ctx, id.upcast()); + let definition = &mut ctx.heap[id]; + let definition_scope = definition.scope; + definition.num_expressions_in_body = self.next_expr_index; + self.visit_scope_and_assign_local_ids(ctx, definition_scope, 0); self.resolve_pending_control_flow_targets(ctx)?; Ok(()) @@ -272,25 +284,24 @@ impl Visitor for PassValidationLinking { //-------------------------------------------------------------------------- fn visit_block_stmt(&mut self, ctx: &mut Ctx, id: BlockStatementId) -> VisitorResult { - let old_scope = self.push_statement_scope(ctx, Scope::Regular(id)); - - // Set end of block + // Get end of block let block_stmt = &ctx.heap[id]; let end_block_id = block_stmt.end_block; // Traverse statements in block let statement_section = self.statement_buffer.start_section_initialized(&block_stmt.statements); + let old_scope = self.push_scope(ctx, ScopeAssociation::Block(id)); assign_and_replace_next_stmt!(self, ctx, id.upcast()); for stmt_idx in 0..statement_section.len() { - self.relative_pos_in_block = stmt_idx as i32; + self.relative_pos_in_parent = stmt_idx as i32; self.visit_stmt(ctx, statement_section[stmt_idx])?; } statement_section.forget(); assign_and_replace_next_stmt!(self, ctx, end_block_id.upcast()); - self.pop_statement_scope(old_scope); + self.pop_scope(old_scope); Ok(()) } @@ -299,7 +310,7 @@ impl Visitor for PassValidationLinking { let expr_id = stmt.initial_expr; let variable_id = stmt.variable; - self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_block, variable_id)?; + self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_parent, variable_id)?; assign_and_replace_next_stmt!(self, ctx, id.upcast().upcast()); debug_assert_eq!(self.expr_parent, ExpressionParent::None); @@ -315,8 +326,8 @@ impl Visitor for PassValidationLinking { let from_id = stmt.from; let to_id = stmt.to; - self.checked_add_local(ctx, 
self.cur_scope, self.relative_pos_in_block, from_id)?; - self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_block, to_id)?; + self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_parent, from_id)?; + self.checked_add_local(ctx, self.cur_scope, self.relative_pos_in_parent, to_id)?; assign_and_replace_next_stmt!(self, ctx, id.upcast().upcast()); Ok(()) @@ -326,7 +337,7 @@ impl Visitor for PassValidationLinking { let stmt = &ctx.heap[id]; let body_id = stmt.body; - self.checked_add_label(ctx, self.relative_pos_in_block, self.in_sync, id)?; + self.checked_add_label(ctx, self.relative_pos_in_parent, self.in_sync, id)?; self.visit_stmt(ctx, body_id)?; Ok(()) @@ -336,8 +347,8 @@ impl Visitor for PassValidationLinking { let if_stmt = &ctx.heap[id]; let end_if_id = if_stmt.end_if; let test_expr_id = if_stmt.test; - let true_stmt_id = if_stmt.true_body; - let false_stmt_id = if_stmt.false_body; + let true_case = if_stmt.true_case; + let false_case = if_stmt.false_case; // Visit test expression debug_assert_eq!(self.expr_parent, ExpressionParent::None); @@ -354,11 +365,15 @@ impl Visitor for PassValidationLinking { // test expression, not on if-statement itself. Hence the if statement // does not have a static subsequent statement. 
assign_then_erase_next_stmt!(self, ctx, id.upcast()); - self.visit_block_stmt(ctx, true_stmt_id)?; + let old_scope = self.push_scope(ctx, ScopeAssociation::If(id, true)); + self.visit_stmt(ctx, true_case.body)?; + self.pop_scope(old_scope); assign_then_erase_next_stmt!(self, ctx, end_if_id.upcast()); - if let Some(false_id) = false_stmt_id { - self.visit_block_stmt(ctx, false_id)?; + if let Some(false_case) = false_case { + let old_scope = self.push_scope(ctx, ScopeAssociation::If(id, false)); + self.visit_stmt(ctx, false_case.body)?; + self.pop_scope(old_scope); assign_then_erase_next_stmt!(self, ctx, end_if_id.upcast()); } @@ -387,7 +402,9 @@ impl Visitor for PassValidationLinking { assign_then_erase_next_stmt!(self, ctx, id.upcast()); self.expr_parent = ExpressionParent::None; - self.visit_block_stmt(ctx, body_stmt_id)?; + let old_scope = self.push_scope(ctx, ScopeAssociation::While(id)); + self.visit_stmt(ctx, body_stmt_id)?; + self.pop_scope(old_scope); self.in_while = old_while; // Link final entry in while's block statement back to the while. The @@ -454,9 +471,9 @@ impl Visitor for PassValidationLinking { let sync_body = ctx.heap[id].body; debug_assert!(self.in_sync.is_invalid()); self.in_sync = id; - let old_scope = self.push_statement_scope(ctx, Scope::Synchronous(id, sync_body)); - self.visit_block_stmt(ctx, sync_body)?; - self.pop_statement_scope(old_scope); + let old_scope = self.push_scope(ctx, ScopeAssociation::Synchronous(id)); + self.visit_stmt(ctx, sync_body)?; + self.pop_scope(old_scope); assign_and_replace_next_stmt!(self, ctx, end_sync_id.upcast()); self.in_sync = SynchronousStatementId::new_invalid(); @@ -482,11 +499,11 @@ impl Visitor for PassValidationLinking { // does not have a single static subsequent statement. It forks and then // each fork has a different next statement. 
assign_then_erase_next_stmt!(self, ctx, id.upcast()); - self.visit_block_stmt(ctx, left_body_id)?; + self.visit_stmt(ctx, left_body_id)?; assign_then_erase_next_stmt!(self, ctx, end_fork_id.upcast()); if let Some(right_body_id) = right_body_id { - self.visit_block_stmt(ctx, right_body_id)?; + self.visit_stmt(ctx, right_body_id)?; assign_then_erase_next_stmt!(self, ctx, end_fork_id.upcast()); } @@ -517,10 +534,9 @@ impl Visitor for PassValidationLinking { let mut case_stmt_ids = self.statement_buffer.start_section(); let num_cases = select_stmt.cases.len(); for case in &select_stmt.cases { - // Note: we add both to the buffer, retrieve them later in indexed - // fashion + // We add them in pairs, so the subsequent for-loop retrieves in pairs case_stmt_ids.push(case.guard); - case_stmt_ids.push(case.block.upcast()); + case_stmt_ids.push(case.body); } assign_then_erase_next_stmt!(self, ctx, id.upcast()); @@ -528,14 +544,12 @@ impl Visitor for PassValidationLinking { for idx in 0..num_cases { let base_idx = 2 * idx; let guard_id = case_stmt_ids[base_idx ]; - let arm_block_id = case_stmt_ids[base_idx + 1]; - debug_assert_eq!(ctx.heap[arm_block_id].as_block().this.upcast(), arm_block_id); // backwards way of saying arm_block_id is a BlockStatementId - let arm_block_id = BlockStatementId(arm_block_id); + let case_body_id = case_stmt_ids[base_idx + 1]; // The guard statement ends up belonging to the block statement // following the arm. The reason we parse it separately is to // extract all of the "get" calls. 
- let old_scope = self.push_statement_scope(ctx, Scope::Regular(arm_block_id)); + let old_scope = self.push_scope(ctx, ScopeAssociation::SelectCase(id, idx as u32)); // Visit the guard of this arm debug_assert!(self.in_select_guard.is_invalid()); @@ -545,8 +559,8 @@ impl Visitor for PassValidationLinking { self.in_select_guard = SelectStatementId::new_invalid(); // Visit the code associated with the guard - self.visit_block_stmt(ctx, arm_block_id)?; - self.pop_statement_scope(old_scope); + self.visit_stmt(ctx, case_body_id)?; + self.pop_scope(old_scope); // Link up last statement in block to EndSelect assign_then_erase_next_stmt!(self, ctx, end_select_id.upcast()); @@ -1327,7 +1341,7 @@ impl Visitor for PassValidationLinking { // Otherwise try to find it if variable_id.is_none() { - variable_id = self.find_variable(ctx, self.relative_pos_in_block, &var_expr.identifier); + variable_id = self.find_variable(ctx, self.relative_pos_in_parent, &var_expr.identifier); } // Otherwise try to see if is a variable introduced by a binding expr @@ -1388,8 +1402,8 @@ impl Visitor for PassValidationLinking { // By now we know that this is a valid binding expression. Given // that a binding expression must be nested under an if/while - // statement, we now add the variable to the (implicit) block - // statement following the if/while statement. + // statement, we now add the variable to the scope associated with + // that statement. 
let bound_identifier = var_expr.identifier.clone(); let bound_variable_id = ctx.heap.alloc_variable(|this| Variable { this, @@ -1402,17 +1416,17 @@ impl Visitor for PassValidationLinking { full_span: bound_identifier.span }, identifier: bound_identifier, - relative_pos_in_block: 0, + relative_pos_in_parent: 0, unique_id_in_scope: -1, }); - let body_stmt_id = match &ctx.heap[self.in_test_expr] { - Statement::If(stmt) => stmt.true_body, - Statement::While(stmt) => stmt.body, + let scope_id = match &ctx.heap[self.in_test_expr] { + Statement::If(stmt) => stmt.true_case.scope, + Statement::While(stmt) => stmt.scope, _ => unreachable!(), }; - let body_scope = Scope::Regular(body_stmt_id); - self.checked_at_single_scope_add_local(ctx, body_scope, -1, bound_variable_id)?; // add at -1 such that first statement can access + + self.checked_at_single_scope_add_local(ctx, scope_id, -1, bound_variable_id)?; // add at -1 such that first statement can find the variable if needed is_binding_target = true; bound_variable_id @@ -1439,123 +1453,139 @@ impl PassValidationLinking { /// sync statement or select statement's arm) then we won't do anything. /// In all cases the caller must call `pop_statement_scope` with the scope /// and relative scope position returned by this function. - fn push_statement_scope(&mut self, ctx: &mut Ctx, new_scope: Scope) -> (Scope, i32) { - let old_scope = self.cur_scope.clone(); - debug_assert!(new_scope.is_block()); // never call for Definition scope - let is_new_block = if old_scope.is_block() { - old_scope.to_block() != new_scope.to_block() - } else { - true + fn push_scope(&mut self, ctx: &mut Ctx, association: ScopeAssociation) -> (ScopeId, i32) { + // Create new scope and assign as ScopeId to the specified associated + // statement. + let is_first_scope = match association { + ScopeAssociation::Definition(_) => true, + _ => false, }; - if !is_new_block { - // No need to push, but still return old scope, we pretend like we - // replaced it. 
- debug_assert!(!ctx.heap[new_scope.to_block()].scope_node.parent.is_invalid()); - return (old_scope, self.relative_pos_in_block); - } + let old_scope_id = self.cur_scope.clone(); + let new_scope_id = ctx.heap.alloc_scope(|this| Scope{ + this, + parent: if is_first_scope { None } else { Some(old_scope_id) }, + nested: Vec::new(), + association, + variables: Vec::new(), + labels: Vec::new(), + relative_pos_in_parent: self.relative_pos_in_parent, + first_unique_id_in_scope: -1, + next_unique_id_in_scope: -1, + }); - // This is a new block, so link it up - if old_scope.is_block() { - let parent_block = &mut ctx.heap[old_scope.to_block()]; - parent_block.scope_node.nested.push(new_scope); + match association { + ScopeAssociation::Definition(definition_id) => { + let def = &mut ctx.heap[definition_id]; + match def { + Definition::Function(def) => def.scope = new_scope_id, + Definition::Component(def) => def.scope = new_scope_id, + _ => unreachable!(), + } + }, + ScopeAssociation::Block(stmt_id) => { + ctx.heap[stmt_id].scope = new_scope_id; + }, + ScopeAssociation::If(stmt_id, in_true_body) => { + let stmt = &mut ctx.heap[stmt_id]; + if in_true_body { + stmt.true_case.scope = new_scope_id; + } else { + let false_body = stmt.false_case.as_mut().unwrap(); + false_body.scope = new_scope_id; + } + }, + ScopeAssociation::While(stmt_id) => { + ctx.heap[stmt_id].scope = new_scope_id; + }, + ScopeAssociation::Synchronous(stmt_id) => { + ctx.heap[stmt_id].scope = new_scope_id; + }, + ScopeAssociation::SelectCase(stmt_id, case_index) => { + let select_stmt = &mut ctx.heap[stmt_id]; + select_stmt.cases[case_index as usize].scope = new_scope_id; + } } - self.cur_scope = new_scope; - - let cur_block = &mut ctx.heap[new_scope.to_block()]; - cur_block.scope_node.parent = old_scope; - cur_block.scope_node.relative_pos_in_parent = self.relative_pos_in_block; + // Link up scopes + if !is_first_scope { + let old_scope = &mut ctx.heap[old_scope_id]; + old_scope.nested.push(new_scope_id); + 
} - let old_relative_pos = self.relative_pos_in_block; - self.relative_pos_in_block = -1; + // Set as current traversal scope, then return old scope + self.cur_scope = new_scope_id; - return (old_scope, old_relative_pos) + let old_relative_pos = self.relative_pos_in_parent; + self.relative_pos_in_parent = 0; + return (old_scope_id, old_relative_pos) } - fn pop_statement_scope(&mut self, scope_to_restore: (Scope, i32)) { + fn pop_scope(&mut self, scope_to_restore: (ScopeId, i32)) { self.cur_scope = scope_to_restore.0; - self.relative_pos_in_block = scope_to_restore.1; + self.relative_pos_in_parent = scope_to_restore.1; } - fn visit_definition_and_assign_local_ids(&mut self, ctx: &mut Ctx, definition_id: DefinitionId) { - let mut var_counter = 0; - - // Set IDs on parameters - let (param_section, body_id) = match &ctx.heap[definition_id] { - Definition::Function(func_def) => ( - self.variable_buffer.start_section_initialized(&func_def.parameters), - func_def.body - ), - Definition::Component(comp_def) => ( - self.variable_buffer.start_section_initialized(&comp_def.parameters), - comp_def.body - ), - _ => unreachable!(), - } ; - - for var_id in param_section.iter_copied() { - let var = &mut ctx.heap[var_id]; - var.unique_id_in_scope = var_counter; - var_counter += 1; - } - - param_section.forget(); + fn visit_scope_and_assign_local_ids(&mut self, ctx: &mut Ctx, scope_id: ScopeId, mut variable_counter: i32) { + let scope = &mut ctx.heap[scope_id]; + scope.first_unique_id_in_scope = variable_counter; - // Recurse into body - self.visit_block_and_assign_local_ids(ctx, body_id, var_counter); - } + let variable_section = self.variable_buffer.start_section_initialized(&scope.variables); + let child_scope_section = self.scope_buffer.start_section_initialized(&scope.nested); - fn visit_block_and_assign_local_ids(&mut self, ctx: &mut Ctx, block_id: BlockStatementId, mut var_counter: i32) { - let block_stmt = &mut ctx.heap[block_id]; - block_stmt.first_unique_id_in_scope = 
var_counter; + let mut variable_index = 0; + let mut child_scope_index = 0; - let var_section = self.variable_buffer.start_section_initialized(&block_stmt.locals); - let mut scope_section = self.statement_buffer.start_section(); - for child_scope in &block_stmt.scope_node.nested { - debug_assert!(child_scope.is_block(), "found a child scope that is not a block statement"); - scope_section.push(child_scope.to_block().upcast()); - } - - let mut var_idx = 0; - let mut scope_idx = 0; - while var_idx < var_section.len() || scope_idx < scope_section.len() { - let relative_var_pos = if var_idx < var_section.len() { - ctx.heap[var_section[var_idx]].relative_pos_in_block + loop { + // Determine relative positions of variable and scope to determine + // which one occurs first within the current scope. + let variable_relative_pos; + if variable_index < variable_section.len() { + let variable_id = variable_section[variable_index]; + let variable = &ctx.heap[variable_id]; + variable_relative_pos = variable.relative_pos_in_parent; } else { - i32::MAX - }; + variable_relative_pos = i32::MAX; + } - let relative_scope_pos = if scope_idx < scope_section.len() { - ctx.heap[scope_section[scope_idx]].as_block().scope_node.relative_pos_in_parent + let child_scope_relative_pos; + if child_scope_index < child_scope_section.len() { + let child_scope_id = child_scope_section[child_scope_index]; + let child_scope = &ctx.heap[child_scope_id]; + child_scope_relative_pos = child_scope.relative_pos_in_parent; } else { - i32::MAX - }; + child_scope_relative_pos = i32::MAX; + } - debug_assert!(!(relative_var_pos == i32::MAX && relative_scope_pos == i32::MAX)); + if variable_relative_pos == i32::MAX && child_scope_relative_pos == i32::MAX { + // Done, no more elements in the scope to consider + break; + } - // In certain cases the relative variable position is the same as - // the scope position (insertion of binding variables). 
In that case - // the variable should be treated first - if relative_var_pos <= relative_scope_pos { - let var = &mut ctx.heap[var_section[var_idx]]; - var.unique_id_in_scope = var_counter; - var_counter += 1; - var_idx += 1; + // Label the variable/scope, whichever comes first. When dealing + // with binding variables, where both variable and scope are + // considered to have the same position in the scope, we treat the + // variable first. + // TODO: Think about this some more, isn't it correct that we + // consider it part of the fresh scope? So we'll have to deal with + // the scope first, and add the variable to that scope? + if variable_relative_pos <= child_scope_relative_pos { + let variable = &mut ctx.heap[variable_section[variable_index]]; + variable.unique_id_in_scope = variable_counter; + variable_counter += 1; + variable_index += 1; } else { - // Boy oh boy - let block_id = ctx.heap[scope_section[scope_idx]].as_block().this; - self.visit_block_and_assign_local_ids(ctx, block_id, var_counter); - scope_idx += 1; + let child_scope_id = child_scope_section[child_scope_index]; + self.visit_scope_and_assign_local_ids(ctx, child_scope_id, variable_counter); + child_scope_index += 1; } } - var_section.forget(); - scope_section.forget(); + variable_section.forget(); + child_scope_section.forget(); - // Done assigning all IDs, assign the last ID to the block statement scope - let block_stmt = &mut ctx.heap[block_id]; - block_stmt.next_unique_id_in_scope = var_counter; + let scope = &mut ctx.heap[scope_id]; + scope.next_unique_id_in_scope = variable_counter; } fn resolve_pending_control_flow_targets(&mut self, ctx: &mut Ctx) -> Result<(), ParseError> { @@ -1614,76 +1644,50 @@ impl PassValidationLinking { /// Adds a local variable to the current scope. It will also annotate the /// `Local` in the AST with its relative position in the block. 
- fn checked_add_local(&mut self, ctx: &mut Ctx, target_scope: Scope, target_relative_pos: i32, id: VariableId) -> Result<(), ParseError> { - debug_assert!(target_scope.is_block()); + fn checked_add_local(&mut self, ctx: &mut Ctx, target_scope_id: ScopeId, target_relative_pos: i32, id: VariableId) -> Result<(), ParseError> { let local = &ctx.heap[id]; // We immediately go to the parent scope. We check the target scope // in the call at the end. That is also where we check for collisions // with symbols. - let block = &ctx.heap[target_scope.to_block()]; - let mut scope = block.scope_node.parent; - let mut cur_relative_pos = block.scope_node.relative_pos_in_parent; - loop { - if let Scope::Definition(definition_id) = scope { - // At outer scope, check parameters of function/component - for parameter_id in ctx.heap[definition_id].parameters() { - let parameter = &ctx.heap[*parameter_id]; - if local.identifier == parameter.identifier { - return Err( - ParseError::new_error_str_at_span( - &ctx.module().source, local.identifier.span, "Local variable name conflicts with parameter" - ).with_info_str_at_span( - &ctx.module().source, parameter.identifier.span, "Parameter definition is found here" - ) - ); - } - } - - // No collisions - break; - } - - // If here then the parent scope is a block scope - let block = &ctx.heap[scope.to_block()]; - - for other_local_id in &block.locals { - let other_local = &ctx.heap[*other_local_id]; - // Position check in case another variable with the same name - // is defined in a higher-level scope, but later than the scope - // in which the current variable resides. 
- if local.this != *other_local_id && - cur_relative_pos >= other_local.relative_pos_in_block && - local.identifier == other_local.identifier { - // Collision within this scope + let mut scope = &ctx.heap[target_scope_id]; + let mut cur_relative_pos = scope.relative_pos_in_parent; + while let Some(scope_parent_id) = scope.parent { + scope = &ctx.heap[scope_parent_id]; + + // Check for collisions: a variable in this parent scope conflicts only + // if it has the same name as the new local, is not the new local itself, + // and sits at or before the position (`cur_relative_pos`) of the nested + // scope we just walked up from. + for variable_id in scope.variables.iter().copied() { + let variable = &ctx.heap[variable_id]; + if local.identifier == variable.identifier && + variable.this != id && + cur_relative_pos >= variable.relative_pos_in_parent { return Err( ParseError::new_error_str_at_span( &ctx.module().source, local.identifier.span, "Local variable name conflicts with another variable" ).with_info_str_at_span( - &ctx.module().source, other_local.identifier.span, "Previous variable is found here" + &ctx.module().source, variable.identifier.span, "Previous variable is found here" ) ); } } - scope = block.scope_node.parent; - cur_relative_pos = block.scope_node.relative_pos_in_parent; + cur_relative_pos = scope.relative_pos_in_parent; } // No collisions in any of the parent scope, attempt to add to scope - self.checked_at_single_scope_add_local(ctx, target_scope, target_relative_pos, id) + self.checked_at_single_scope_add_local(ctx, target_scope_id, target_relative_pos, id) } /// Adds a local variable to the specified scope. Will check the specified /// scope for variable conflicts and the symbol table for global conflicts. /// Will NOT check parent scopes of the specified scope. 
fn checked_at_single_scope_add_local( - &mut self, ctx: &mut Ctx, scope: Scope, relative_pos: i32, id: VariableId + &mut self, ctx: &mut Ctx, scope_id: ScopeId, relative_pos: i32, new_variable_id: VariableId ) -> Result<(), ParseError> { // Check the symbol table for conflicts { let cur_scope = SymbolScope::Definition(self.def_type.definition_id()); - let ident = &ctx.heap[id].identifier; + let ident = &ctx.heap[new_variable_id].identifier; if let Some(symbol) = ctx.symbols.get_symbol_by_name(cur_scope, &ident.value.as_bytes()) { return Err(ParseError::new_error_str_at_span( &ctx.module().source, ident.span, @@ -1695,32 +1699,31 @@ impl PassValidationLinking { } // Check the specified scope for conflicts - let local = &ctx.heap[id]; + let new_variable = &ctx.heap[new_variable_id]; + let scope = &ctx.heap[scope_id]; - debug_assert!(scope.is_block()); - let block = &ctx.heap[scope.to_block()]; - for other_local_id in &block.locals { - let other_local = &ctx.heap[*other_local_id]; - if local.this != other_local.this && + for variable_id in scope.variables.iter().copied() { + let old_variable = &ctx.heap[variable_id]; + if new_variable.this != old_variable.this && // relative_pos >= other_local.relative_pos_in_block && - local.identifier == other_local.identifier { + new_variable.identifier == old_variable.identifier { // Collision return Err( ParseError::new_error_str_at_span( - &ctx.module().source, local.identifier.span, "Local variable name conflicts with another variable" + &ctx.module().source, new_variable.identifier.span, "Local variable name conflicts with another variable" ).with_info_str_at_span( - &ctx.module().source, other_local.identifier.span, "Previous variable is found here" + &ctx.module().source, old_variable.identifier.span, "Previous variable is found here" ) ); } } // No collisions - let block = &mut ctx.heap[scope.to_block()]; - block.locals.push(id); + let scope = &mut ctx.heap[scope_id]; + scope.variables.push(new_variable_id); - let local = 
&mut ctx.heap[id]; - local.relative_pos_in_block = relative_pos; + let variable = &mut ctx.heap[new_variable_id]; + variable.relative_pos_in_parent = relative_pos; Ok(()) } @@ -1728,85 +1731,66 @@ impl PassValidationLinking { /// Finds a variable in the visitor's scope that must appear before the /// specified relative position within that block. fn find_variable(&self, ctx: &Ctx, mut relative_pos: i32, identifier: &Identifier) -> Option<VariableId> { - debug_assert!(self.cur_scope.is_block()); + let mut scope_id = self.cur_scope; - // No need to use iterator over namespaces if here - let mut scope = &self.cur_scope; - loop { - debug_assert!(scope.is_block()); - let block = &ctx.heap[scope.to_block()]; + // Check if we can find the variable in the current scope + let scope = &ctx.heap[scope_id]; - for local_id in &block.locals { - let local = &ctx.heap[*local_id]; + for variable_id in scope.variables.iter().copied() { + let variable = &ctx.heap[variable_id]; - if local.relative_pos_in_block < relative_pos && identifier == &local.identifier { - return Some(*local_id); + if variable.relative_pos_in_parent < relative_pos && identifier == &variable.identifier { + return Some(variable_id); } } - scope = &block.scope_node.parent; - if !scope.is_block() { - // Definition scope, need to check arguments to definition - match scope { - Scope::Definition(definition_id) => { - let definition = &ctx.heap[*definition_id]; - for parameter_id in definition.parameters() { - let parameter = &ctx.heap[*parameter_id]; - if identifier == &parameter.identifier { - return Some(*parameter_id); - } - } - }, - _ => unreachable!(), - } - - // Variable could not be found - return None - } else { - relative_pos = block.scope_node.relative_pos_in_parent; + // Could not find variable, move to parent scope and try again + if scope.parent.is_none() { + return None; } + + scope_id = scope.parent.unwrap(); + relative_pos = scope.relative_pos_in_parent; } } /// Adds a particular label to the current scope. 
Will return an error if /// there is another label with the same name visible in the current scope. - fn checked_add_label(&mut self, ctx: &mut Ctx, relative_pos: i32, in_sync: SynchronousStatementId, id: LabeledStatementId) -> Result<(), ParseError> { - debug_assert!(self.cur_scope.is_block()); - + fn checked_add_label(&mut self, ctx: &mut Ctx, relative_pos: i32, in_sync: SynchronousStatementId, new_label_id: LabeledStatementId) -> Result<(), ParseError> { // Make sure label is not defined within the current scope or any of the // parent scope. - let label = &mut ctx.heap[id]; - label.relative_pos_in_block = relative_pos; - label.in_sync = in_sync; + let new_label = &mut ctx.heap[new_label_id]; + new_label.relative_pos_in_parent = relative_pos; + new_label.in_sync = in_sync; - let label = &ctx.heap[id]; - let mut scope = &self.cur_scope; + let new_label = &ctx.heap[new_label_id]; + let mut scope_id = self.cur_scope; loop { - debug_assert!(scope.is_block(), "scope is not a block"); - let block = &ctx.heap[scope.to_block()]; - for other_label_id in &block.labels { - let other_label = &ctx.heap[*other_label_id]; - if other_label.label == label.label { + let scope = &ctx.heap[scope_id]; + for existing_label_id in scope.labels.iter().copied() { + let existing_label = &ctx.heap[existing_label_id]; + if existing_label.label == new_label.label { // Collision return Err(ParseError::new_error_str_at_span( - &ctx.module().source, label.label.span, "label name is used more than once" + &ctx.module().source, new_label.label.span, "label name is used more than once" ).with_info_str_at_span( - &ctx.module().source, other_label.label.span, "the other label is found here" + &ctx.module().source, existing_label.label.span, "the other label is found here" )); } } - scope = &block.scope_node.parent; - if !scope.is_block() { + if scope.parent.is_none() { break; } + + scope_id = scope.parent.unwrap(); } // No collisions - let block = &mut ctx.heap[self.cur_scope.to_block()]; - 
block.labels.push(id); + let scope = &mut ctx.heap[self.cur_scope]; + scope.labels.push(new_label_id); Ok(()) } @@ -1814,60 +1798,57 @@ impl PassValidationLinking { /// Finds a particular labeled statement by its identifier. Once found it /// will make sure that the target label does not skip over any variable /// declarations within the scope in which the label was found. - fn find_label(mut scope: Scope, ctx: &Ctx, identifier: &Identifier) -> Result { - debug_assert!(scope.is_block()); - + fn find_label(mut scope_id: ScopeId, ctx: &Ctx, identifier: &Identifier) -> Result { loop { - debug_assert!(scope.is_block(), "scope is not a block"); - let relative_scope_pos = ctx.heap[scope.to_block()].scope_node.relative_pos_in_parent; + let scope = &ctx.heap[scope_id]; + let relative_scope_pos = scope.relative_pos_in_parent; - let block = &ctx.heap[scope.to_block()]; - for label_id in &block.labels { - let label = &ctx.heap[*label_id]; + for label_id in scope.labels.iter().copied() { + let label = &ctx.heap[label_id]; if label.label == *identifier { - for local_id in &block.locals { + // Found the target label, now make sure that the jump to + // the label doesn't imply a skipped variable declaration + for variable_id in scope.variables.iter().copied() { // TODO: Better to do this in control flow analysis, it // is legal to skip over a variable declaration if it - // is not actually being used. I might be missing - // something here when laying out the bytecode... - let local = &ctx.heap[*local_id]; - if local.relative_pos_in_block > relative_scope_pos && local.relative_pos_in_block < label.relative_pos_in_block { + // is not actually being used. 
+ let variable = &ctx.heap[variable_id]; + if variable.relative_pos_in_parent > relative_scope_pos && variable.relative_pos_in_parent < label.relative_pos_in_parent { return Err( ParseError::new_error_str_at_span(&ctx.module().source, identifier.span, "this target label skips over a variable declaration") .with_info_str_at_span(&ctx.module().source, label.label.span, "because it jumps to this label") - .with_info_str_at_span(&ctx.module().source, local.identifier.span, "which skips over this variable") + .with_info_str_at_span(&ctx.module().source, variable.identifier.span, "which skips over this variable") ); } } - return Ok(*label_id); + return Ok(label_id); } } - scope = block.scope_node.parent; - if !scope.is_block() { + if scope.parent.is_none() { return Err(ParseError::new_error_str_at_span( &ctx.module().source, identifier.span, "could not find this label" )); } + scope_id = scope.parent.unwrap(); } } - /// This function will check if the provided while statement ID has a block - /// statement that is one of our current parents. - fn has_parent_while_scope(mut scope: Scope, ctx: &Ctx, id: WhileStatementId) -> bool { - let while_stmt = &ctx.heap[id]; + /// This function will check if the provided scope has a parent that belongs + /// to a while statement. 
+ fn scope_is_nested_in_while_statement(mut scope_id: ScopeId, ctx: &Ctx, expected_while_id: WhileStatementId) -> bool { + let while_stmt = &ctx.heap[expected_while_id]; + loop { - debug_assert!(scope.is_block()); - let block = scope.to_block(); - if while_stmt.body == block { + let scope = &ctx.heap[scope_id]; + if scope.this == while_stmt.scope { return true; } - let block = &ctx.heap[block]; - scope = block.scope_node.parent; - if !scope.is_block() { - return false; + match scope.parent { + Some(new_scope_id) => scope_id = new_scope_id, + None => return false, // walked all the way up, not encountering the while statement } } } @@ -1886,9 +1867,10 @@ impl PassValidationLinking { // Make sure break target is a while statement let target = &ctx.heap[target_id]; if let Statement::While(target_stmt) = &ctx.heap[target.body] { - // Even though we have a target while statement, the break might not be - // present underneath this particular labeled while statement - if !Self::has_parent_while_scope(control_flow.in_scope, ctx, target_stmt.this) { + // Even though we have a target while statement, the control + // flow statement might not be present underneath this + // particular labeled while statement. + if !Self::scope_is_nested_in_while_statement(control_flow.in_scope, ctx, target_stmt.this) { return Err(ParseError::new_error_str_at_span( &ctx.module().source, label.span, "break statement is not nested under the target label's while statement" ).with_info_str_at_span( diff --git a/src/protocol/parser/type_table.rs b/src/protocol/parser/type_table.rs index de0d8cc5461dc9dee9fa87b537774f81a014d8e8..f6e909cd3196a83f0206091359c3cdbbeed0d611 100644 --- a/src/protocol/parser/type_table.rs +++ b/src/protocol/parser/type_table.rs @@ -199,7 +199,7 @@ pub struct StructField { /// `FunctionType` is what you expect it to be: a particular function's /// signature. 
pub struct FunctionType { - pub return_types: Vec, + pub return_type: ParserType, pub arguments: Vec, } @@ -935,12 +935,9 @@ impl TypeTable { let root_id = definition.defined_in; // Check and construct return types and argument types. - debug_assert_eq!(definition.return_types.len(), 1, "not one return type"); - for return_type in &definition.return_types { - Self::check_member_parser_type( - modules, ctx, root_id, return_type, definition.builtin - )?; - } + Self::check_member_parser_type( + modules, ctx, root_id, &definition.return_type, definition.builtin + )?; let mut arguments = Vec::with_capacity(definition.parameters.len()); for parameter_id in &definition.parameters { @@ -963,9 +960,8 @@ impl TypeTable { // Construct internal representation of function type let mut poly_vars = Self::create_polymorphic_variables(&definition.poly_vars); - for return_type in &definition.return_types { - Self::mark_used_polymorphic_variables(&mut poly_vars, return_type); - } + + Self::mark_used_polymorphic_variables(&mut poly_vars, &definition.return_type); for argument in &arguments { Self::mark_used_polymorphic_variables(&mut poly_vars, &argument.parser_type); } @@ -975,7 +971,7 @@ impl TypeTable { self.type_lookup.insert(definition_id, DefinedType{ ast_root: root_id, ast_definition: definition_id, - definition: DefinedTypeVariant::Function(FunctionType{ return_types: definition.return_types.clone(), arguments }), + definition: DefinedTypeVariant::Function(FunctionType{ return_type: definition.return_type.clone(), arguments }), poly_vars, is_polymorph }); diff --git a/src/protocol/parser/visitor.rs b/src/protocol/parser/visitor.rs index 3d62f1c603b4b29b8d640dee5b355d8186fc3383..ad92f4dd13b1a946e00f40c5aabed836c76e5938 100644 --- a/src/protocol/parser/visitor.rs +++ b/src/protocol/parser/visitor.rs @@ -6,8 +6,10 @@ use crate::protocol::symbol_table::{SymbolTable}; type Unit = (); pub(crate) type VisitorResult = Result; -/// Globally configured vector capacity for buffers in 
visitor implementations -pub(crate) const BUFFER_INIT_CAPACITY: usize = 256; +/// Globally configured capacity for large-ish buffers in visitor impls +pub(crate) const BUFFER_INIT_CAP_LARGE: usize = 256; +/// Globally configured capacity for small-ish buffers in visitor impls +pub(crate) const BUFFER_INIT_CAP_SMALL: usize = 64; /// General context structure that is used while traversing the AST. pub(crate) struct Ctx<'p> { diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index 63ff0f95929262c11c44f701f0f31015589658dd..461d2207031d776df65f16a2c52952449c5a7e4a 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -517,15 +517,13 @@ impl<'a> FunctionTester<'a> { pub(crate) fn for_variable(self, name: &str, f: F) -> Self { // Seek through the blocks in order to find the variable - let wrapping_block_id = seek_stmt( - self.ctx.heap, self.def.body.upcast(), - &|stmt| { - if let Statement::Block(block) = stmt { - for local_id in &block.locals { - let var = &self.ctx.heap[*local_id]; - if var.identifier.value.as_str() == name { - return true; - } + let wrapping_scope = seek_scope( + self.ctx.heap, self.def.scope, + &|scope| { + for variable_id in scope.variables.iter().copied() { + let var = &self.ctx.heap[variable_id]; + if var.identifier.value.as_str() == name { + return true; } } @@ -534,13 +532,13 @@ impl<'a> FunctionTester<'a> { ); let mut found_local_id = None; - if let Some(block_id) = wrapping_block_id { - // Found the right block, find the variable inside the block again - let block_stmt = self.ctx.heap[block_id].as_block(); - for local_id in &block_stmt.locals { - let var = &self.ctx.heap[*local_id]; - if var.identifier.value.as_str() == name { - found_local_id = Some(*local_id); + if let Some(scope_id) = wrapping_scope { + // Found the right scope, find the variable inside the block again + let scope = &self.ctx.heap[scope_id]; + for variable_id in scope.variables.iter().copied() { + let variable = 
&self.ctx.heap[variable_id]; + if variable.identifier.value.as_str() == name { + found_local_id = Some(variable_id); } } } @@ -1098,23 +1096,36 @@ fn seek_stmt bool>(heap: &Heap, start: StatementId, f: &F) }, Statement::Labeled(stmt) => seek_stmt(heap, stmt.body, f), Statement::If(stmt) => { - if let Some(id) = seek_stmt(heap, stmt.true_body.upcast(), f) { + if let Some(id) = seek_stmt(heap, stmt.true_case.body, f) { return Some(id); - } else if let Some(false_body) = stmt.false_body { - if let Some(id) = seek_stmt(heap, false_body.upcast(), f) { + } else if let Some(false_body) = stmt.false_case { + if let Some(id) = seek_stmt(heap, false_body.body, f) { return Some(id); } } None }, - Statement::While(stmt) => seek_stmt(heap, stmt.body.upcast(), f), - Statement::Synchronous(stmt) => seek_stmt(heap, stmt.body.upcast(), f), + Statement::While(stmt) => seek_stmt(heap, stmt.body, f), + Statement::Synchronous(stmt) => seek_stmt(heap, stmt.body, f), _ => None }; matched } +fn seek_scope bool>(heap: &Heap, start: ScopeId, f: &F) -> Option { + let scope = &heap[start]; + if f(scope) { return Some(start); } + + for child_scope_id in scope.nested.iter().copied() { + if let Some(result) = seek_scope(heap, child_scope_id, f) { + return Some(result); + } + } + + return None; +} + fn seek_expr_in_expr bool>(heap: &Heap, start: ExpressionId, f: &F) -> Option { let expr = &heap[start]; if f(expr) { return Some(start); } @@ -1215,9 +1226,9 @@ fn seek_expr_in_stmt bool>(heap: &Heap, start: StatementId Statement::If(stmt) => { None .or_else(|| seek_expr_in_expr(heap, stmt.test, f)) - .or_else(|| seek_expr_in_stmt(heap, stmt.true_body.upcast(), f)) - .or_else(|| if let Some(false_body) = stmt.false_body { - seek_expr_in_stmt(heap, false_body.upcast(), f) + .or_else(|| seek_expr_in_stmt(heap, stmt.true_case.body, f)) + .or_else(|| if let Some(false_body) = stmt.false_case { + seek_expr_in_stmt(heap, false_body.body, f) } else { None }) @@ -1225,10 +1236,10 @@ fn seek_expr_in_stmt 
bool>(heap: &Heap, start: StatementId Statement::While(stmt) => { None .or_else(|| seek_expr_in_expr(heap, stmt.test, f)) - .or_else(|| seek_expr_in_stmt(heap, stmt.body.upcast(), f)) + .or_else(|| seek_expr_in_stmt(heap, stmt.body, f)) }, Statement::Synchronous(stmt) => { - seek_expr_in_stmt(heap, stmt.body.upcast(), f) + seek_expr_in_stmt(heap, stmt.body, f) }, Statement::Return(stmt) => { for expr_id in &stmt.expressions {