diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index 63ff0f95929262c11c44f701f0f31015589658dd..461d2207031d776df65f16a2c52952449c5a7e4a 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -517,15 +517,13 @@ impl<'a> FunctionTester<'a> { pub(crate) fn for_variable(self, name: &str, f: F) -> Self { // Seek through the blocks in order to find the variable - let wrapping_block_id = seek_stmt( - self.ctx.heap, self.def.body.upcast(), - &|stmt| { - if let Statement::Block(block) = stmt { - for local_id in &block.locals { - let var = &self.ctx.heap[*local_id]; - if var.identifier.value.as_str() == name { - return true; - } + let wrapping_scope = seek_scope( + self.ctx.heap, self.def.scope, + &|scope| { + for variable_id in scope.variables.iter().copied() { + let var = &self.ctx.heap[variable_id]; + if var.identifier.value.as_str() == name { + return true; } } @@ -534,13 +532,13 @@ impl<'a> FunctionTester<'a> { ); let mut found_local_id = None; - if let Some(block_id) = wrapping_block_id { - // Found the right block, find the variable inside the block again - let block_stmt = self.ctx.heap[block_id].as_block(); - for local_id in &block_stmt.locals { - let var = &self.ctx.heap[*local_id]; - if var.identifier.value.as_str() == name { - found_local_id = Some(*local_id); + if let Some(scope_id) = wrapping_scope { + // Found the right scope, find the variable inside the block again + let scope = &self.ctx.heap[scope_id]; + for variable_id in scope.variables.iter().copied() { + let variable = &self.ctx.heap[variable_id]; + if variable.identifier.value.as_str() == name { + found_local_id = Some(variable_id); } } } @@ -1098,23 +1096,36 @@ fn seek_stmt bool>(heap: &Heap, start: StatementId, f: &F) }, Statement::Labeled(stmt) => seek_stmt(heap, stmt.body, f), Statement::If(stmt) => { - if let Some(id) = seek_stmt(heap, stmt.true_body.upcast(), f) { + if let Some(id) = seek_stmt(heap, stmt.true_case.body, f) { return Some(id); - } else if let Some(false_body) = stmt.false_body { - if let Some(id) = seek_stmt(heap, false_body.upcast(), f) { + } else if let Some(false_body) = stmt.false_case { + if let Some(id) = seek_stmt(heap, false_body.body, f) { return Some(id); } } None }, - Statement::While(stmt) => seek_stmt(heap, stmt.body.upcast(), f), - Statement::Synchronous(stmt) => seek_stmt(heap, stmt.body.upcast(), f), + Statement::While(stmt) => seek_stmt(heap, stmt.body, f), + Statement::Synchronous(stmt) => seek_stmt(heap, stmt.body, f), _ => None }; matched } +fn seek_scope bool>(heap: &Heap, start: ScopeId, f: &F) -> Option { + let scope = &heap[start]; + if f(scope) { return Some(start); } + + for child_scope_id in scope.nested.iter().copied() { + if let Some(result) = seek_scope(heap, child_scope_id, f) { + return Some(result); + } + } + + return None; +} + fn seek_expr_in_expr bool>(heap: &Heap, start: ExpressionId, f: &F) -> Option { let expr = &heap[start]; if f(expr) { return Some(start); } @@ -1215,9 +1226,9 @@ fn seek_expr_in_stmt bool>(heap: &Heap, start: StatementId Statement::If(stmt) => { None .or_else(|| seek_expr_in_expr(heap, stmt.test, f)) - .or_else(|| seek_expr_in_stmt(heap, stmt.true_body.upcast(), f)) - .or_else(|| if let Some(false_body) = stmt.false_body { - seek_expr_in_stmt(heap, false_body.upcast(), f) + .or_else(|| seek_expr_in_stmt(heap, stmt.true_case.body, f)) + .or_else(|| if let Some(false_body) = stmt.false_case { + seek_expr_in_stmt(heap, false_body.body, f) } else { None }) @@ -1225,10 +1236,10 @@ fn seek_expr_in_stmt bool>(heap: &Heap, start: StatementId Statement::While(stmt) => { None .or_else(|| seek_expr_in_expr(heap, stmt.test, f)) - .or_else(|| seek_expr_in_stmt(heap, stmt.body.upcast(), f)) + .or_else(|| seek_expr_in_stmt(heap, stmt.body, f)) }, Statement::Synchronous(stmt) => { - seek_expr_in_stmt(heap, stmt.body.upcast(), f) + seek_expr_in_stmt(heap, stmt.body, f) }, Statement::Return(stmt) => { for expr_id in &stmt.expressions {