diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index f3188e3a773a0851341e89df1bd93e3778eea1b2..e530437f860aadee15fe7196090177f6668a3b49 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -139,6 +139,31 @@ impl AstOkTester { ); unreachable!() } + + pub(crate) fn for_function(self, name: &str, f: F) -> Self { + let mut found = false; + for definition in self.heap.definitions.iter() { + if let Definition::Function(definition) = definition { + if String::from_utf8_lossy(&definition.identifier.value) != name { + continue; + } + + // Found function + let tester = FunctionTester::new(&self.test_name, definition, &self.heap); + f(tester); + found = true; + break; + } + } + + if found { return self } + + assert!( + false, "[{}] failed to find definition for function '{}'", + self.test_name, name + ); + unreachable!(); + } } //------------------------------------------------------------------------------ @@ -228,6 +253,55 @@ impl<'a> StructFieldTester<'a> { } } +pub(crate) struct FunctionTester<'a> { + test_name: &'a str, + def: &'a Function, + heap: &'a Heap, +} + +impl<'a> FunctionTester<'a> { + fn new(test_name: &'a str, def: &'a Function, heap: &'a Heap) -> Self { + Self{ test_name, def, heap } + } + + pub(crate) fn for_variable(self, name: &str, f: F) -> Self { + let mem_stmt_id = seek_stmt( + self.heap, self.def.body, + |stmt| { + if let Statement::Local(local) = stmt { + if let LocalStatement::Memory(memory) = local { + let local = &self.heap[memory.variable]; + if local.identifier.value == name.as_bytes() { + return true; + } + } + } + + false + } + ); + + match mem_stmt_id { + Some(mem_stmt_id) => { + // TODO: Retrieve shit + }, + None => { + // TODO: Throw error + } + } + } +} + + +pub(crate) struct VariableTester<'a> { + test_name: &'a str, + def: &'a Local, + assignment: &'a AssignmentExpression, + heap: &'a Heap, +} + + + //------------------------------------------------------------------------------ // Interface for failed compilation //------------------------------------------------------------------------------ @@ -368,4 +442,151 @@ fn serialize_parser_type(buffer: &mut String, heap: &Heap, id: ParserTypeId) { } } } +} + +fn seek_stmt bool>(heap: &Heap, start: StatementId, f: F) -> Option { + let stmt = &heap[start]; + if f(stmt) { return Some(start); } + + // This statement wasn't it, try to recurse + let matched = match stmt { + Statement::Block(block) => { + for sub_id in &block.statements { + if let Some(id) = seek_stmt(heap, *sub_id, f) { + return Some(id); + } + } + + None + }, + Statement::Labeled(stmt) => seek_stmt(heap, stmt.body, f), + Statement::If(stmt) => { + if let Some(id) = seek_stmt(heap,stmt.true_body, f) { + return Some(id); + } else if let Some(id) = seek_stmt(heap, stmt.false_body, f) { + return Some(id); + } + None + }, + Statement::While(stmt) => seek_stmt(heap, stmt.body, f), + Statement::Synchronous(stmt) => seek_stmt(heap, stmt.body, f), + _ => None + }; + + matched +} + +fn seek_expr_in_expr bool>(heap: &Heap, start: ExpressionId, f: F) -> Option { + let expr = &heap[start]; + if f(expr) { return Some(start); } + + match expr { + Expression::Assignment(expr) => { + None + .or_else(|| seek_expr_in_expr(heap, expr.left, f)) + .or_else(|| seek_expr_in_expr(heap, expr.right, f)) + }, + Expression::Conditional(expr) => { + None + .or_else(|| seek_expr_in_expr(heap, expr.test, f)) + .or_else(|| seek_expr_in_expr(heap, expr.true_expression, f)) + .or_else(|| seek_expr_in_expr(heap, expr.false_expression, f)) + }, + Expression::Binary(expr) => { + None + .or_else(|| seek_expr_in_expr(heap, expr.left, f)) + .or_else(|| seek_expr_in_expr(heap, expr.right, f)) + }, + Expression::Unary(expr) => { + seek_expr_in_expr(heap, expr.expression, f) + }, + Expression::Indexing(expr) => { + None + .or_else(|| seek_expr_in_expr(heap, expr.subject, f)) + .or_else(|| seek_expr_in_expr(heap, expr.index, f)) + }, + Expression::Slicing(expr) => { + None + .or_else(|| seek_expr_in_expr(heap, expr.subject, f)) + .or_else(|| seek_expr_in_expr(heap, expr.from_index, f)) + .or_else(|| seek_expr_in_expr(heap, expr.to_index, f)) + }, + Expression::Select(expr) => { + seek_expr_in_expr(heap, expr.subject, f) + }, + Expression::Array(expr) => { + for element in &expr.elements { + if let Some(id) = seek_expr_in_expr(heap, *element, f) { + return Some(id) + } + } + None + }, + Expression::Literal(expr) => { + if let Literal::Struct(lit) = expr.value { + for field in &lit.fields { + if let Some(id) = seek_expr_in_expr(heap, field.value, f) { + return Some(id) + } + } + } + None + }, + Expression::Call(expr) => { + for arg in &expr.arguments { + if let Some(id) = seek_expr_in_expr(heap, *arg, f) { + return Some(id) + } + } + None + }, + Expression::Variable(expr) => { + None + } + } +} + +fn seek_expr_in_stmt bool>(heap: &Heap, start: StatementId, f: F) -> Option { + let stmt = &heap[start]; + + match stmt { + Statement::Block(stmt) => { + for stmt_id in &stmt.statements { + if let Some(id) = seek_expr_in_stmt(heap, *stmt_id, f) { + return Some(id) + } + } + None + }, + Statement::Labeled(stmt) => { + seek_expr_in_stmt(heap, stmt.body, f) + }, + Statement::If(stmt) => { + None + .or_else(|| seek_expr_in_expr(heap, stmt.test, f)) + .or_else(|| seek_expr_in_stmt(heap, stmt.true_body, f)) + .or_else(|| seek_expr_in_stmt(heap, stmt.false_body, f)) + }, + Statement::While(stmt) => { + None + .or_else(|| seek_expr_in_expr(heap, stmt.test, f)) + .or_else(|| seek_expr_in_stmt(heap, stmt.body, f)) + }, + Statement::Synchronous(stmt) => { + seek_expr_in_stmt(heap, stmt.body, f) + }, + Statement::Return(stmt) => { + seek_expr_in_expr(heap, stmt.expression, f) + }, + Statement::Assert(stmt) => { + seek_expr_in_expr(heap, stmt.expression, f) + }, + Statement::New(stmt) => { + seek_expr_in_expr(heap, stmt.expression.upcast(), f) + }, + Statement::Expression(stmt) => { + seek_expr_in_expr(heap, stmt.expression, f) + }, + _ => None + } } \ No newline at end of file