diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index ece23a2f6126c52d67029237c90612ac3d21ddfb..b3758135f2ce67f881e0fb6912ff9fe436ab87b3 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -1,10 +1,11 @@ use crate::protocol::{ ast::*, - inputsource::*, + input_source::*, parser::{ *, type_table::TypeTable, symbol_table::SymbolTable, + token_parsing::*, }, }; @@ -63,8 +64,8 @@ impl Tester { pub(crate) fn compile(self) -> AstTesterResult { let mut parser = Parser::new(); for (source_idx, source) in self.sources.into_iter().enumerate() { - let mut cursor = std::io::Cursor::new(source); - let input_source = InputSource::new("", &mut cursor) + let source = source.into_bytes(); + let input_source = InputSource::new(String::from(""), source) .expect(&format!("parsing source {}", source_idx + 1)); if let Err(err) = parser.feed(input_source) { @@ -140,7 +141,7 @@ impl AstOkTester { let mut found = false; for definition in self.heap.definitions.iter() { if let Definition::Struct(definition) = definition { - if String::from_utf8_lossy(&definition.identifier.value) != name { + if definition.identifier.value.as_str() != name { continue; } @@ -163,7 +164,7 @@ impl AstOkTester { let mut found = false; for definition in self.heap.definitions.iter() { if let Definition::Enum(definition) = definition { - if String::from_utf8_lossy(&definition.identifier.value) != name { + if definition.identifier.value.as_str() != name { continue; } @@ -186,7 +187,7 @@ impl AstOkTester { let mut found = false; for definition in self.heap.definitions.iter() { if let Definition::Union(definition) = definition { - if String::from_utf8_lossy(&definition.identifier.value) != name { + if definition.identifier.value.as_str() != name { continue; } @@ -209,7 +210,7 @@ impl AstOkTester { let mut found = false; for definition in self.heap.definitions.iter() { if let Definition::Function(definition) = definition { - if String::from_utf8_lossy(&definition.identifier.value) != name { + if definition.identifier.value.as_str() != name { continue; } @@ -287,7 +288,7 @@ impl<'a> StructTester<'a> { pub(crate) fn for_field(self, name: &str, f: F) -> Self { // Find field with specified name for field in &self.def.fields { - if String::from_utf8_lossy(&field.field.value) == name { + if field.field.value.as_str() == name { let tester = StructFieldTester::new(self.ctx, field); f(tester); return self; @@ -304,11 +305,11 @@ impl<'a> StructTester<'a> { fn assert_postfix(&self) -> String { let mut v = String::new(); v.push_str("Struct{ name: "); - v.push_str(&String::from_utf8_lossy(&self.def.identifier.value)); + v.push_str(self.def.identifier.value.as_str()); v.push_str(", fields: ["); for (field_idx, field) in self.def.fields.iter().enumerate() { if field_idx != 0 { v.push_str(", "); } - v.push_str(&String::from_utf8_lossy(&field.field.value)); + v.push_str(field.field.value.as_str()); } v.push_str("] }"); v @@ -327,7 +328,7 @@ impl<'a> StructFieldTester<'a> { pub(crate) fn assert_parser_type(self, expected: &str) -> Self { let mut serialized_type = String::new(); - serialize_parser_type(&mut serialized_type, &self.ctx.heap, self.def.parser_type); + serialize_parser_type(&mut serialized_type, &self.ctx.heap, &self.def.parser_type); assert_eq!( expected, &serialized_type, "[{}] Expected type '{}', but got '{}' for {}", @@ -338,11 +339,8 @@ impl<'a> StructFieldTester<'a> { fn assert_postfix(&self) -> String { let mut serialized_type = String::new(); - serialize_parser_type(&mut serialized_type, &self.ctx.heap, self.def.parser_type); - format!( - "StructField{{ name: {}, parser_type: {} }}", - String::from_utf8_lossy(&self.def.field.value), serialized_type - ) + serialize_parser_type(&mut serialized_type, &self.ctx.heap, &self.def.parser_type); + format!("StructField{{ name: {}, parser_type: {} }}", self.def.field.value.as_str(), serialized_type) } } @@ -386,11 +384,11 @@ impl<'a> EnumTester<'a> { pub(crate) fn assert_postfix(&self) -> String { let mut v = String::new(); v.push_str("Enum{ name: "); - v.push_str(&String::from_utf8_lossy(&self.def.identifier.value)); + v.push_str(self.def.identifier.value.as_str()); v.push_str(", variants: ["); for (variant_idx, variant) in self.def.variants.iter().enumerate() { if variant_idx != 0 { v.push_str(", "); } - v.push_str(&String::from_utf8_lossy(&variant.identifier.value)); + v.push_str(variant.identifier.value.as_str()); } v.push_str("] }"); v @@ -437,11 +435,11 @@ impl<'a> UnionTester<'a> { fn assert_postfix(&self) -> String { let mut v = String::new(); v.push_str("Union{ name: "); - v.push_str(&String::from_utf8_lossy(&self.def.identifier.value)); + v.push_str(self.def.identifier.value.as_str()); v.push_str(", variants: ["); for (variant_idx, variant) in self.def.variants.iter().enumerate() { if variant_idx != 0 { v.push_str(", "); } - v.push_str(&String::from_utf8_lossy(&variant.identifier.value)); + v.push_str(variant.identifier.value.as_str()); } v.push_str("] }"); v @@ -461,12 +459,12 @@ impl<'a> FunctionTester<'a> { pub(crate) fn for_variable(self, name: &str, f: F) -> Self { // Find the memory statement in order to find the local let mem_stmt_id = seek_stmt( - self.ctx.heap, self.def.body, + self.ctx.heap, self.def.body.upcast(), &|stmt| { if let Statement::Local(local) = stmt { if let LocalStatement::Memory(memory) = local { let local = &self.ctx.heap[memory.variable]; - if local.identifier.value == name.as_bytes() { + if local.identifier.value.as_str() == name { return true; } } @@ -487,7 +485,7 @@ impl<'a> FunctionTester<'a> { // Find the assignment expression that follows it let assignment_id = seek_expr_in_stmt( - self.ctx.heap, self.def.body, + self.ctx.heap, self.def.body.upcast(), &|expr| { if let Expression::Assignment(assign_expr) = expr { if let Expression::Variable(variable_expr) = &self.ctx.heap[assign_expr.left] { @@ -552,7 +550,7 @@ impl<'a> FunctionTester<'a> { // Use the inner match index to find the expression let expr_id = seek_expr_in_stmt( - &self.ctx.heap, self.def.body, + &self.ctx.heap, self.def.body.upcast(), &|expr| expr.position().offset == inner_match_idx ); assert!( @@ -573,10 +571,7 @@ impl<'a> FunctionTester<'a> { } fn assert_postfix(&self) -> String { - format!( - "Function{{ name: {} }}", - &String::from_utf8_lossy(&self.def.identifier.value) - ) + format!("Function{{ name: {} }}", self.def.identifier.value.as_str()) } } @@ -596,7 +591,7 @@ impl<'a> VariableTester<'a> { pub(crate) fn assert_parser_type(self, expected: &str) -> Self { let mut serialized = String::new(); - serialize_parser_type(&mut serialized, self.ctx.heap, self.local.parser_type); + serialize_parser_type(&mut serialized, self.ctx.heap, &self.local.parser_type); assert_eq!( expected, &serialized, @@ -622,11 +617,7 @@ impl<'a> VariableTester<'a> { } fn assert_postfix(&self) -> String { - println!("DEBUG: {:?}", self.assignment.concrete_type); - format!( - "Variable{{ name: {} }}", - &String::from_utf8_lossy(&self.local.identifier.value) - ) + format!("Variable{{ name: {} }}", self.local.identifier.value.as_str()) } } @@ -823,57 +814,83 @@ fn has_monomorph<'a>(ctx: TestCtx<'a>, definition_id: DefinitionId, serialized_m (has_match, full_buffer) } -fn serialize_parser_type(buffer: &mut String, heap: &Heap, id: ParserTypeId) { +fn serialize_parser_type(buffer: &mut String, heap: &Heap, parser_type: &ParserType) { use ParserTypeVariant as PTV; - let p = &heap[id]; - match &p.variant { - PTV::Message => buffer.push_str("msg"), - PTV::Bool => buffer.push_str("bool"), - PTV::Byte => buffer.push_str("byte"), - PTV::Short => buffer.push_str("short"), - PTV::Int => buffer.push_str("int"), - PTV::Long => buffer.push_str("long"), - PTV::String => buffer.push_str("string"), - PTV::IntegerLiteral => buffer.push_str("intlit"), - PTV::Inferred => buffer.push_str("auto"), - PTV::Array(sub_id) => { - serialize_parser_type(buffer, heap, *sub_id); - buffer.push_str("[]"); - }, - PTV::Input(sub_id) => { - buffer.push_str("in<"); - serialize_parser_type(buffer, heap, *sub_id); - buffer.push('>'); - }, - PTV::Output(sub_id) => { - buffer.push_str("out<"); - serialize_parser_type(buffer, heap, *sub_id); - buffer.push('>'); - }, - PTV::Symbolic(symbolic) => { - buffer.push_str(&String::from_utf8_lossy(&symbolic.identifier.value)); - if symbolic.poly_args2.len() > 0 { + fn write_bytes(buffer: &mut String, bytes: &[u8]) { + let utf8 = String::from_utf8_lossy(bytes); + buffer.push_str(&utf8); + } + + fn serialize_variant(buffer: &mut String, heap: &Heap, parser_type: &ParserType, mut idx: usize) -> usize { + match &parser_type.elements[idx].variant { + PTV::Message => write_bytes(buffer, KW_TYPE_MESSAGE), + PTV::Bool => write_bytes(buffer, KW_TYPE_BOOL), + PTV::UInt8 => write_bytes(buffer, KW_TYPE_UINT8), + PTV::UInt16 => write_bytes(buffer, KW_TYPE_UINT16), + PTV::UInt32 => write_bytes(buffer, KW_TYPE_UINT32), + PTV::UInt64 => write_bytes(buffer, KW_TYPE_UINT64), + PTV::SInt8 => write_bytes(buffer, KW_TYPE_SINT8), + PTV::SInt16 => write_bytes(buffer, KW_TYPE_SINT16), + PTV::SInt32 => write_bytes(buffer, KW_TYPE_SINT32), + PTV::SInt64 => write_bytes(buffer, KW_TYPE_SINT64), + PTV::Character => write_bytes(buffer, KW_TYPE_CHAR), + PTV::String => write_bytes(buffer, KW_TYPE_STRING), + PTV::IntegerLiteral => buffer.push_str("int_literal"), + PTV::Inferred => write_bytes(buffer, KW_TYPE_INFERRED), + PTV::Array => { + idx = serialize_variant(buffer, heap, parser_type, idx + 1); + buffer.push_str("[]"); + }, + PTV::Input => { + write_bytes(buffer, KW_TYPE_IN_PORT); buffer.push('<'); - for (poly_idx, poly_arg) in symbolic.poly_args2.iter().enumerate() { - if poly_idx != 0 { buffer.push(','); } - serialize_parser_type(buffer, heap, *poly_arg); - } + idx = serialize_variant(buffer, heap, parser_type, idx + 1); + buffer.push('>'); + }, + PTV::Output => { + write_bytes(buffer, KW_TYPE_OUT_PORT); + buffer.push('<'); + idx = serialize_variant(buffer, heap, parser_type, idx + 1); buffer.push('>'); + }, + PTV::PolymorphicArgument(definition_id, poly_idx) => { + let definition = &heap[*definition_id]; + let poly_arg = &definition.poly_vars()[*poly_idx]; + buffer.push_str(poly_arg.value.as_str()); + }, + PTV::Definition(definition_id, num_embedded) => { + let definition = &heap[*definition_id]; + buffer.push_str(definition.identifier().value.as_str()); + + let num_embedded = *num_embedded; + if num_embedded != 0 { + buffer.push('<'); + for embedded_idx in 0..num_embedded { + if embedded_idx != 0 { + buffer.push(','); + } + idx = serialize_variant(buffer, heap, parser_type, idx + 1); + } + buffer.push('>'); + } } } + + idx } + + serialize_variant(buffer, heap, parser_type, 0); } fn serialize_concrete_type(buffer: &mut String, heap: &Heap, def: DefinitionId, concrete: &ConcreteType) { // Retrieve polymorphic variables - let poly_vars = match &heap[def] { - Definition::Function(definition) => &definition.poly_vars, - Definition::Component(definition) => &definition.poly_vars, - Definition::Struct(definition) => &definition.poly_vars, - Definition::Enum(definition) => &definition.poly_vars, - Definition::Union(definition) => &definition.poly_vars, - }; + let poly_vars = heap[def].poly_vars(); + + fn write_bytes(buffer: &mut String, bytes: &[u8]) { + let utf8 = String::from_utf8_lossy(bytes); + buffer.push_str(&utf8); + } fn serialize_recursive( buffer: &mut String, heap: &Heap, poly_vars: &Vec, concrete: &ConcreteType, mut idx: usize @@ -883,16 +900,21 @@ fn serialize_concrete_type(buffer: &mut String, heap: &Heap, def: DefinitionId, let part = &concrete.parts[idx]; match part { CTP::Marker(poly_idx) => { - buffer.push_str(&String::from_utf8_lossy(&poly_vars[*poly_idx].value)); + buffer.push_str(poly_vars[*poly_idx].value.as_str()); }, CTP::Void => buffer.push_str("void"), - CTP::Message => buffer.push_str("msg"), - CTP::Bool => buffer.push_str("bool"), - CTP::Byte => buffer.push_str("byte"), - CTP::Short => buffer.push_str("short"), - CTP::Int => buffer.push_str("int"), - CTP::Long => buffer.push_str("long"), - CTP::String => buffer.push_str("string"), + CTP::Message => write_bytes(buffer, KW_TYPE_MESSAGE), + CTP::Bool => write_bytes(buffer, KW_TYPE_BOOL), + CTP::UInt8 => write_bytes(buffer, KW_TYPE_UINT8), + CTP::UInt16 => write_bytes(buffer, KW_TYPE_UINT16), + CTP::UInt32 => write_bytes(buffer, KW_TYPE_UINT32), + CTP::UInt64 => write_bytes(buffer, KW_TYPE_UINT64), + CTP::SInt8 => write_bytes(buffer, KW_TYPE_SINT8), + CTP::SInt16 => write_bytes(buffer, KW_TYPE_SINT16), + CTP::SInt32 => write_bytes(buffer, KW_TYPE_SINT32), + CTP::SInt64 => write_bytes(buffer, KW_TYPE_SINT64), + CTP::Character => write_bytes(buffer, KW_TYPE_CHAR), + CTP::String => write_bytes(buffer, KW_TYPE_STRING), CTP::Array => { idx = serialize_recursive(buffer, heap, poly_vars, concrete, idx + 1); buffer.push_str("[]"); @@ -902,18 +924,20 @@ fn serialize_concrete_type(buffer: &mut String, heap: &Heap, def: DefinitionId, buffer.push_str("[..]"); }, CTP::Input => { - buffer.push_str("in<"); + write_bytes(buffer, KW_TYPE_IN_PORT); + buffer.push('<'); idx = serialize_recursive(buffer, heap, poly_vars, concrete, idx + 1); buffer.push('>'); }, CTP::Output => { - buffer.push_str("out<"); + write_bytes(buffer, KW_TYPE_OUT_PORT); + buffer.push('<'); idx = serialize_recursive(buffer, heap, poly_vars, concrete, idx + 1); buffer.push('>'); }, CTP::Instance(definition_id, num_sub) => { let definition_name = heap[*definition_id].identifier(); - buffer.push_str(&String::from_utf8_lossy(&definition_name.value)); + buffer.push_str(definition_name.value.as_str()); if *num_sub != 0 { buffer.push('<'); for sub_idx in 0..*num_sub { @@ -961,15 +985,17 @@ fn seek_stmt bool>(heap: &Heap, start: StatementId, f: &F) }, Statement::Labeled(stmt) => seek_stmt(heap, stmt.body, f), Statement::If(stmt) => { - if let Some(id) = seek_stmt(heap,stmt.true_body, f) { - return Some(id); - } else if let Some(id) = seek_stmt(heap, stmt.false_body, f) { + if let Some(id) = seek_stmt(heap, stmt.true_body.upcast(), f) { return Some(id); + } else if let Some(false_body) = stmt.false_body { + if let Some(id) = seek_stmt(heap, false_body.upcast(), f) { + return Some(id); + } } None }, - Statement::While(stmt) => seek_stmt(heap, stmt.body, f), - Statement::Synchronous(stmt) => seek_stmt(heap, stmt.body, f), + Statement::While(stmt) => seek_stmt(heap, stmt.body.upcast(), f), + Statement::Synchronous(stmt) => seek_stmt(heap, stmt.body.upcast(), f), _ => None }; @@ -1019,14 +1045,6 @@ fn seek_expr_in_expr bool>(heap: &Heap, start: ExpressionI Expression::Select(expr) => { seek_expr_in_expr(heap, expr.subject, f) }, - Expression::Array(expr) => { - for element in &expr.elements { - if let Some(id) = seek_expr_in_expr(heap, *element, f) { - return Some(id) - } - } - None - }, Expression::Literal(expr) => { if let Literal::Struct(lit) = &expr.value { for field in &lit.fields { @@ -1034,6 +1052,12 @@ fn seek_expr_in_expr bool>(heap: &Heap, start: ExpressionI return Some(id) } } + } else if let Literal::Array(elements) = &expr.value { + for element in elements { + if let Some(id) = seek_expr_in_expr(heap, *element, f) { + return Some(id) + } + } } None }, @@ -1069,16 +1093,20 @@ fn seek_expr_in_stmt bool>(heap: &Heap, start: StatementId Statement::If(stmt) => { None .or_else(|| seek_expr_in_expr(heap, stmt.test, f)) - .or_else(|| seek_expr_in_stmt(heap, stmt.true_body, f)) - .or_else(|| seek_expr_in_stmt(heap, stmt.false_body, f)) + .or_else(|| seek_expr_in_stmt(heap, stmt.true_body.upcast(), f)) + .or_else(|| if let Some(false_body) = stmt.false_body { + seek_expr_in_stmt(heap, false_body.upcast(), f) + } else { + None + }) }, Statement::While(stmt) => { None .or_else(|| seek_expr_in_expr(heap, stmt.test, f)) - .or_else(|| seek_expr_in_stmt(heap, stmt.body, f)) + .or_else(|| seek_expr_in_stmt(heap, stmt.body.upcast(), f)) }, Statement::Synchronous(stmt) => { - seek_expr_in_stmt(heap, stmt.body, f) + seek_expr_in_stmt(heap, stmt.body.upcast(), f) }, Statement::Return(stmt) => { seek_expr_in_expr(heap, stmt.expression, f)