diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index 209c50695e626572d8575659c5804a7db511f15a..4c15ce3e30162774258dad5a945057904ac1862b 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -398,6 +398,60 @@ impl<'a> FunctionTester<'a> { self } + /// Finds a specific expression within a function. There are two matchers: + /// one outer matcher (to find a rough indication of the expression) and an + /// inner matcher to find the exact expression. + /// + /// The reason being that, for example, a function's body might be littered + /// with addition symbols, so we first match on "some_var + some_other_var", + /// and then match exactly on "+". + pub(crate) fn for_expression_by_source(self, outer_match: &str, inner_match: &str, f: F) -> Self { + // Seek the expression in the source code + assert!(outer_match.contains(inner_match), "improper testing code"); + + let module = seek_def_in_modules( + &self.ctx.heap, &self.ctx.modules, self.def.this.upcast() + ).unwrap(); + + // Find the first occurrence of the expression after the definition of + // the function, we'll check that it is included in the body later. + let mut outer_match_idx = self.def.position.offset; + while outer_match_idx < module.source.input.len() { + if module.source.input[outer_match_idx..].starts_with(outer_match.as_bytes()) { + break; + } + outer_match_idx += 1 + } + + assert!( + outer_match_idx < module.source.input.len(), + "[{}] Failed to find '{}' within the source that contains {}", + self.ctx.test_name, outer_match, self.assert_postfix() + ); + let inner_match_idx = outer_match_idx + outer_match.find(inner_match).unwrap(); + + // Use the inner match index to find the expression + let expr_id = seek_expr_in_stmt( + &self.ctx.heap, self.def.body, + &|expr| expr.position().offset == inner_match_idx + ); + assert!( + expr_id.is_some(), + "[{}] Failed to find '{}' within the source that contains {} \ + (note: expression was found, but not within the specified function", + self.ctx.test_name, outer_match, self.assert_postfix() + ); + let expr_id = expr_id.unwrap(); + + // We have the expression, call the testing function + let tester = ExpressionTester::new( + self.ctx, self.def.this.upcast(), &self.ctx.heap[expr_id] + ); + f(tester); + + self + } + fn assert_postfix(&self) -> String { format!( "Function{{ name: {} }}", @@ -406,7 +460,6 @@ impl<'a> FunctionTester<'a> { } } - pub(crate) struct VariableTester<'a> { ctx: TestCtx<'a>, definition_id: DefinitionId, @@ -456,6 +509,42 @@ impl<'a> VariableTester<'a> { } } +pub(crate) struct ExpressionTester<'a> { + ctx: TestCtx<'a>, + definition_id: DefinitionId, // of the enclosing function/component + expr: &'a Expression +} + +impl<'a> ExpressionTester<'a> { + fn new( + ctx: TestCtx<'a>, definition_id: DefinitionId, expr: &'a Expression + ) -> Self { + Self{ ctx, definition_id, expr } + } + + pub(crate) fn assert_concrete_type(self, expected: &str) -> Self { + let mut serialized = String::new(); + serialize_concrete_type( + &mut serialized, self.ctx.heap, self.definition_id, + self.expr.get_type() + ); + + assert_eq!( + expected, &serialized, + "[{}] Expected concrete type '{}', but got '{}' for {}", + self.ctx.test_name, expected, &serialized, self.assert_postfix() + ); + self + } + + fn assert_postfix(&self) -> String { + format!( + "Expression{{ debug: {:?} }}", + self.expr + ) + } +} + //------------------------------------------------------------------------------ // Interface for failed compilation //------------------------------------------------------------------------------ @@ -602,9 +691,10 @@ fn serialize_concrete_type(buffer: &mut String, heap: &Heap, def: DefinitionId, // Retrieve polymorphic variables, if present (since we're dealing with a // concrete type we only expect procedure types) let poly_vars = match &heap[def] { - Definition::Function(func) => &func.poly_vars, - Definition::Component(comp) => &comp.poly_vars, - _ => unreachable!("Error in testing utility: did not expect non-procedure type for concrete type serialization"), + Definition::Function(definition) => &definition.poly_vars, + Definition::Component(definition) => &definition.poly_vars, + Definition::Struct(definition) => &definition.poly_vars, + _ => unreachable!("Error in testing utility: unexpected type for concrete type serialization"), }; fn serialize_recursive( @@ -666,6 +756,19 @@ fn serialize_concrete_type(buffer: &mut String, heap: &Heap, def: DefinitionId, serialize_recursive(buffer, heap, poly_vars, concrete, 0); } +fn seek_def_in_modules<'a>(heap: &Heap, modules: &'a [LexedModule], def_id: DefinitionId) -> Option<&'a LexedModule> { + for module in modules { + let root = &heap.protocol_descriptions[module.root_id]; + for definition in &root.definitions { + if *definition == def_id { + return Some(module) + } + } + } + + None +} + fn seek_stmt bool>(heap: &Heap, start: StatementId, f: &F) -> Option { let stmt = &heap[start]; if f(stmt) { return Some(start); }