diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index d599014400b8b193e8bded700bdcba2fdfeafb65..802e5625a6ccb1640a5a36cea919568fada4d3d7 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -218,7 +218,7 @@ impl AstOkTester { pub(crate) fn for_function(self, name: &str, f: F) -> Self { let mut found = false; for definition in self.heap.definitions.iter() { - if let Definition::Function(definition) = definition { + if let Definition::Procedure(definition) = definition { if definition.identifier.value.as_str() != name { continue; } @@ -507,11 +507,11 @@ impl<'a> UnionTester<'a> { pub(crate) struct FunctionTester<'a> { ctx: TestCtx<'a>, - def: &'a FunctionDefinition, + def: &'a ProcedureDefinition, } impl<'a> FunctionTester<'a> { - fn new(ctx: TestCtx<'a>, def: &'a FunctionDefinition) -> Self { + fn new(ctx: TestCtx<'a>, def: &'a ProcedureDefinition) -> Self { Self{ ctx, def } } @@ -700,11 +700,11 @@ impl<'a> FunctionTester<'a> { use crate::protocol::*; // Assuming the function is not polymorphic - let definition_id = self.def.this.upcast(); + let definition_id = self.def.this; let func_type = [ConcreteTypePart::Function(definition_id, 0)]; - let mono_index = self.ctx.types.get_procedure_monomorph_type_id(&definition_id, &func_type).unwrap(); + let mono_index = self.ctx.types.get_procedure_monomorph_type_id(&definition_id.upcast(), &func_type).unwrap(); - let mut prompt = Prompt::new(&self.ctx.types, &self.ctx.heap, self.def.this.upcast(), mono_index, ValueGroup::new_stack(Vec::new())); + let mut prompt = Prompt::new(&self.ctx.types, &self.ctx.heap, definition_id, mono_index, ValueGroup::new_stack(Vec::new())); let mut call_context = FakeRunContext{}; loop { let result = prompt.step(&self.ctx.types, &self.ctx.heap, &self.ctx.modules, &mut call_context); @@ -806,14 +806,11 @@ impl<'a> ExpressionTester<'a> { } fn get_procedure_monomorph<'a>(heap: &Heap, types: &'a TypeTable, definition_id: DefinitionId) -> &'a ProcedureMonomorph { - let ast_definition = &heap[definition_id]; - let func_type = if ast_definition.is_function() { - [ConcreteTypePart::Function(definition_id, 0)] - } else if ast_definition.is_component() { - [ConcreteTypePart::Component(definition_id, 0)] + let ast_definition = heap[definition_id].as_procedure(); + let func_type = if ast_definition.kind == ProcedureKind::Function { + [ConcreteTypePart::Function(ast_definition.this, 0)] } else { - assert!(false); - unreachable!() + [ConcreteTypePart::Component(ast_definition.this, 0)] }; let mono_index = types.get_procedure_monomorph_type_id(&definition_id, &func_type).unwrap(); @@ -928,10 +925,14 @@ fn has_equal_num_monomorphs(ctx: TestCtx, num: usize, definition_id: DefinitionI for mono in &ctx.types.mono_types { match &mono.concrete_type.parts[0] { - ConcreteTypePart::Instance(def_id, _) | + ConcreteTypePart::Instance(def_id, _) => { + if *def_id == definition_id { + num_on_type += 1; + } + } ConcreteTypePart::Function(def_id, _) | ConcreteTypePart::Component(def_id, _) => { - if *def_id == definition_id { + if def_id.upcast() == definition_id { num_on_type += 1; } }, @@ -964,11 +965,11 @@ fn has_monomorph(ctx: TestCtx, definition_id: DefinitionId, serialized_monomorph }; // Bit wasteful, but this is (temporary?) testing code: - for (mono_idx, mono) in ctx.types.mono_types.iter().enumerate() { + for (_mono_idx, mono) in ctx.types.mono_types.iter().enumerate() { let got_definition_id = match &mono.concrete_type.parts[0] { - ConcreteTypePart::Instance(v, _) | + ConcreteTypePart::Instance(v, _) => *v, ConcreteTypePart::Function(v, _) | - ConcreteTypePart::Component(v, _) => *v, + ConcreteTypePart::Component(v, _) => v.upcast(), _ => DefinitionId::new_invalid(), }; if got_definition_id == definition_id {