diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index f5aaff9b5281820d684fff34c06aab35d63ddea9..5b1c07dd45fe71c3196c467b4d7bacd49b7c2e97 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -161,6 +161,31 @@ impl AstOkTester { unreachable!() } + pub(crate) fn for_enum(self, name: &str, f: F) -> Self { + let mut found = false; + for definition in self.heap.definitions.iter() { + if let Definition::Enum(definition) = definition { + if String::from_utf8_lossy(&definition.identifier.value) != name { + continue; + } + + // Found enum with the same name + let tester = EnumTester::new(self.ctx(), definition); + f(tester); + found = true; + break; + } + } + + if found { return self } + + assert!( + false, "[{}] Failed to find definition for enum '{}'", + self.test_name, name + ); + unreachable!() + } + pub(crate) fn for_function(self, name: &str, f: F) -> Self { let mut found = false; for definition in self.heap.definitions.iter() { @@ -221,11 +246,10 @@ impl<'a> StructTester<'a> { } pub(crate) fn assert_num_monomorphs(self, num: usize) -> Self { - let type_def = self.ctx.types.get_base_definition(&self.def.this.upcast()).unwrap(); - assert_eq!( - num, type_def.monomorphs.len(), - "[{}] Expected {} monomorphs, but found {} for {}", - self.ctx.test_name, num, type_def.monomorphs.len(), self.assert_postfix() + let (is_equal, num_encountered) = has_equal_num_monomorphs(self.ctx, num, self.def.this.upcast()); + assert!( + is_equal, "[{}] Expected {} monomorphs, but got {} for {}", + self.ctx.test_name, num, num_encountered, self.assert_postfix() ); self } @@ -233,35 +257,10 @@ impl<'a> StructTester<'a> { /// Asserts that a monomorph exist, separate polymorphic variable types by /// a semicolon. pub(crate) fn assert_has_monomorph(self, serialized_monomorph: &str) -> Self { - let definition_id = self.def.this.upcast(); - let type_def = self.ctx.types.get_base_definition(&definition_id).unwrap(); - - let mut full_buffer = String::new(); - full_buffer.push('['); - for (monomorph_idx, monomorph) in type_def.monomorphs.iter().enumerate() { - let mut buffer = String::new(); - for (element_idx, monomorph_element) in monomorph.iter().enumerate() { - if element_idx != 0 { buffer.push(';'); } - serialize_concrete_type(&mut buffer, self.ctx.heap, definition_id, monomorph_element); - } - - if buffer == serialized_monomorph { - // Found an exact match - return self - } - - if monomorph_idx != 0 { - full_buffer.push_str(", "); - } - full_buffer.push('"'); - full_buffer.push_str(&buffer); - full_buffer.push('"'); - } - full_buffer.push(']'); - + let (has_monomorph, serialized) = has_monomorph(self.ctx, self.def.this.upcast(), serialized_monomorph); assert!( - false, "[{}] Expected to find monomorph {}, but got {} for {}", - self.ctx.test_name, serialized_monomorph, &full_buffer, self.assert_postfix() + has_monomorph, "[{}] Expected to find monomorph {}, but got {} for {}", + self.ctx.test_name, serialized_monomorph, &serialized, self.assert_postfix() ); self } @@ -328,6 +327,57 @@ impl<'a> StructFieldTester<'a> { } } +pub(crate) struct EnumTester<'a> { + ctx: TestCtx<'a>, + def: &'a EnumDefinition, +} + +impl<'a> EnumTester<'a> { + fn new(ctx: TestCtx<'a>, def: &'a EnumDefinition) -> Self { + Self{ ctx, def } + } + + pub(crate) fn assert_num_variants(self, num: usize) -> Self { + assert_eq!( + num, self.def.variants.len(), + "[{}] Expected {} enum variants, but found {} for {}", + self.ctx.test_name, num, self.def.variants.len(), self.assert_postfix() + ); + self + } + + pub(crate) fn assert_num_monomorphs(self, num: usize) -> Self { + let (is_equal, num_encountered) = has_equal_num_monomorphs(self.ctx, num, self.def.this.upcast()); + assert!( + is_equal, "[{}] Expected {} monomorphs, but got {} for {}", + self.ctx.test_name, num, num_encountered, self.assert_postfix() + ); + self + } + + pub(crate) fn assert_has_monomorph(self, serialized_monomorph: &str) -> Self { + let (has_monomorph, serialized) = has_monomorph(self.ctx, self.def.this.upcast(), serialized_monomorph); + assert!( + has_monomorph, "[{}] Expected to find monomorph {}, but got {} for {}", + self.ctx.test_name, serialized_monomorph, serialized, self.assert_postfix() + ); + self + } + + pub(crate) fn assert_postfix(&self) -> String { + let mut v = String::new(); + v.push_str("Enum{ name: "); + v.push_str(&String::from_utf8_lossy(&self.def.identifier.value)); + v.push_str(", variants: ["); + for (variant_idx, variant) in self.def.variants.iter().enumerate() { + if variant_idx != 0 { v.push_str(", "); } + v.push_str(&String::from_utf8_lossy(&variant.identifier.value)); + } + v.push_str("] }"); + v + } +} + pub(crate) struct FunctionTester<'a> { ctx: TestCtx<'a>, def: &'a Function, @@ -645,6 +695,43 @@ impl<'a> ErrorTester<'a> { // Generic utilities //------------------------------------------------------------------------------ +fn has_equal_num_monomorphs<'a>(ctx: TestCtx<'a>, num: usize, definition_id: DefinitionId) -> (bool, usize) { + let type_def = ctx.types.get_base_definition(&definition_id).unwrap(); + let num_on_type = type_def.monomorphs.len(); + + (num_on_type == num, num_on_type) +} + +fn has_monomorph<'a>(ctx: TestCtx<'a>, definition_id: DefinitionId, serialized_monomorph: &str) -> (bool, String) { + let type_def = ctx.types.get_base_definition(&definition_id).unwrap(); + + let mut full_buffer = String::new(); + let mut has_match = false; + full_buffer.push('['); + for (monomorph_idx, monomorph) in type_def.monomorphs.iter().enumerate() { + let mut buffer = String::new(); + for (element_idx, monomorph_element) in monomorph.iter().enumerate() { + if element_idx != 0 { buffer.push(';'); } + serialize_concrete_type(&mut buffer, ctx.heap, definition_id, monomorph_element); + } + + if buffer == serialized_monomorph { + // Found an exact match + has_match = true; + } + + if monomorph_idx != 0 { + full_buffer.push_str(", "); + } + full_buffer.push('"'); + full_buffer.push_str(&buffer); + full_buffer.push('"'); + } + full_buffer.push(']'); + + (has_match, full_buffer) +} + fn serialize_parser_type(buffer: &mut String, heap: &Heap, id: ParserTypeId) { use ParserTypeVariant as PTV; @@ -740,12 +827,14 @@ fn serialize_concrete_type(buffer: &mut String, heap: &Heap, def: DefinitionId, CTP::Instance(definition_id, num_sub) => { let definition_name = heap[*definition_id].identifier(); buffer.push_str(&String::from_utf8_lossy(&definition_name.value)); - buffer.push('<'); - for sub_idx in 0..*num_sub { - if sub_idx != 0 { buffer.push(','); } - idx = serialize_recursive(buffer, heap, poly_vars, concrete, idx + 1); + if *num_sub != 0 { + buffer.push('<'); + for sub_idx in 0..*num_sub { + if sub_idx != 0 { buffer.push(','); } + idx = serialize_recursive(buffer, heap, poly_vars, concrete, idx + 1); + } + buffer.push('>'); } - buffer.push('>'); idx += 1; } }