diff --git a/src/macros.rs b/src/macros.rs index 461f7bb31a05087ab72bebb6b92c5406f16ad044..61f92de410ed10e4dbb8b993b023a83732777a6f 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -1,3 +1,14 @@ +macro_rules! enabled_debug_print { + (false, $name:literal, $format:literal) => {}; + (false, $name:literal, $format:literal, $($args:expr),*) => {}; + (true, $name:literal, $format:literal) => { + println!("[{}] {}", $name, $format) + }; + (true, $name:literal, $format:literal, $($args:expr),*) => { + println!("[{}] {}", $name, format!($format, $($args),*)) + }; +} + /* Change the definition of these macros to control the logging level statically */ diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index 829ec525d7d64a682ab17ae0f0ee5cafbbc07d14..de59ceb8784acf851202f9c4f8aa0cfb0ebfa85c 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -1004,6 +1004,16 @@ pub struct LiteralStruct { pub(crate) definition: Option } +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct LiteralEnum { + // Phase 1: parser + pub(crate) identifier: NamespacedIdentifier, + pub(crate) poly_args: Vec, + // Phase 2: linker + pub(crate) definition: Option, + pub(crate) variant_idx: usize, +} + #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub enum Method { Get, @@ -1031,6 +1041,13 @@ impl Field { _ => false, } } + + pub fn as_symbolic(&self) -> &FieldSymbolic { + match self { + Field::Symbolic(v) => v, + _ => unreachable!("attempted to get Field::Symbolic from {:?}", self) + } + } } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] diff --git a/src/protocol/lexer.rs b/src/protocol/lexer.rs index 498b671031126ad6c0a4f759114894e0c9d1d167..ed10f5f71043a5f1c545eb5c289c9eb5422c2786 100644 --- a/src/protocol/lexer.rs +++ b/src/protocol/lexer.rs @@ -4,6 +4,28 @@ use crate::protocol::inputsource::*; const MAX_LEVEL: usize = 128; const MAX_NAMESPACES: u8 = 8; // only three levels are supported at the moment +macro_rules! 
debug_log { + ($format:literal) => { + enabled_debug_print!(false, "lexer", $format); + }; + ($format:literal, $($args:expr),*) => { + enabled_debug_print!(false, "lexer", $format, $($args),*); + }; +} + +macro_rules! debug_line { + ($source:expr) => { + { + let mut buffer = String::with_capacity(128); + for idx in 0..buffer.capacity() { + let next = $source.lookahead(idx); + if next.is_none() || Some(b'\n') == next { break; } + buffer.push(next.unwrap() as char); + } + buffer + } + }; +} fn is_vchar(x: Option) -> bool { if let Some(c) = x { c >= 0x21 && c <= 0x7E @@ -169,6 +191,10 @@ impl Lexer<'_> { } // Word boundary + let next = self.source.lookahead(keyword.len()); + if next.is_none() { return true; } + return !is_ident_rest(next); + if let Some(next) = self.source.lookahead(keyword.len()) { !(next >= b'A' && next <= b'Z' || next >= b'a' && next <= b'z') } else { @@ -249,34 +275,35 @@ impl Lexer<'_> { Ok(Some(elements)) } /// Essentially the same as `consume_comma_separated`, but will not allocate - /// memory. Will return `true` and leave the input position at the end of - /// the comma-separated list if well formed. Otherwise returns `false` and - /// leaves the input position at a "random" position. + /// memory. Will return `Ok(true)` and leave the input position at the end + /// of the comma-separated list if well formed and `Ok(false)` if the list is + /// not present. Otherwise returns `Err(())` and leaves the input position + /// at a "random" position. 
fn consume_comma_separated_spilled_without_pos_recovery bool>( &mut self, open: u8, close: u8, func: F - ) -> bool { + ) -> Result { if Some(open) != self.source.next() { - return true; + return Ok(false); } self.source.consume(); - if self.consume_whitespace(false).is_err() { return false }; + if self.consume_whitespace(false).is_err() { return Err(()) }; let mut had_comma = true; loop { if Some(close) == self.source.next() { self.source.consume(); - return true; + return Ok(true); } else if !had_comma { - return false; + return Err(()); } - if !func(self) { return false; } - if self.consume_whitespace(false).is_err() { return false }; + if !func(self) { return Err(()); } + if self.consume_whitespace(false).is_err() { return Err(()) }; had_comma = self.source.next() == Some(b','); if had_comma { self.source.consume(); - if self.consume_whitespace(false).is_err() { return false; } + if self.consume_whitespace(false).is_err() { return Err(()); } } } } @@ -480,6 +507,7 @@ impl Lexer<'_> { }; // Consume the type + debug_log!("consume_type2: {}", debug_line!(self.source)); let pos = self.source.pos(); let parser_type_variant = if self.has_keyword(b"msg") { self.consume_keyword(b"msg")?; @@ -600,6 +628,7 @@ impl Lexer<'_> { /// position. fn maybe_consume_type_spilled_without_pos_recovery(&mut self) -> bool { // Consume type identifier + debug_log!("maybe_consume_type_spilled_...: {}", debug_line!(self.source)); if self.has_type_keyword() { self.consume_any_chars(); } else { @@ -610,11 +639,15 @@ impl Lexer<'_> { // Consume any polymorphic arguments that follow the type identifier let mut backup_pos = self.source.pos(); if self.consume_whitespace(false).is_err() { return false; } - if !self.maybe_consume_poly_args_spilled_without_pos_recovery() { return false; } + match self.maybe_consume_poly_args_spilled_without_pos_recovery() { + Ok(true) => backup_pos = self.source.pos(), + Ok(false) => {}, + Err(()) => return false + } + // Consume any array specifiers. 
Make sure we always leave the input // position at the end of the last array specifier if we do find a // valid type - if self.consume_whitespace(false).is_err() { return false; } while let Some(b'[') = self.source.next() { self.source.consume(); @@ -641,8 +674,10 @@ impl Lexer<'_> { /// Attempts to consume polymorphic arguments without returning them. If it /// doesn't encounter well-formed polymorphic arguments, then the input - /// position is left at a "random" position. - fn maybe_consume_poly_args_spilled_without_pos_recovery(&mut self) -> bool { + /// position is left at a "random" position. Returns a boolean indicating if + /// the poly_args list was present. + fn maybe_consume_poly_args_spilled_without_pos_recovery(&mut self) -> Result { + debug_log!("maybe_consume_poly_args_spilled_...: {}", debug_line!(self.source)); self.consume_comma_separated_spilled_without_pos_recovery( b'<', b'>', |lexer| { lexer.maybe_consume_type_spilled_without_pos_recovery() @@ -1430,7 +1465,7 @@ impl Lexer<'_> { let backup_pos = self.source.pos(); let result = self.consume_namespaced_identifier_spilled().is_ok() && self.consume_whitespace(false).is_ok() && - self.maybe_consume_poly_args_spilled_without_pos_recovery() && + self.maybe_consume_poly_args_spilled_without_pos_recovery().is_ok() && self.consume_whitespace(false).is_ok() && self.source.next() == Some(b'{'); @@ -1440,6 +1475,7 @@ impl Lexer<'_> { fn consume_struct_literal_expression(&mut self, h: &mut Heap) -> Result { // Consume identifier and polymorphic arguments + debug_log!("consume_struct_literal_expression: {}", debug_line!(self.source)); let position = self.source.pos(); let identifier = self.consume_namespaced_identifier()?; self.consume_whitespace(false)?; @@ -1492,7 +1528,7 @@ impl Lexer<'_> { if self.consume_namespaced_identifier_spilled().is_ok() && self.consume_whitespace(false).is_ok() && - self.maybe_consume_poly_args_spilled_without_pos_recovery() && + 
self.maybe_consume_poly_args_spilled_without_pos_recovery().is_ok() && self.consume_whitespace(false).is_ok() && self.source.next() == Some(b'(') { // Seems like we have a function call or an enum literal @@ -1506,6 +1542,7 @@ impl Lexer<'_> { let position = self.source.pos(); // Consume method identifier + debug_log!("consume_call_expression: {}", debug_line!(self.source)); let method; if self.has_keyword(b"get") { self.consume_keyword(b"get")?; @@ -1564,6 +1601,7 @@ impl Lexer<'_> { h: &mut Heap, ) -> Result { let position = self.source.pos(); + debug_log!("consume_variable_expression: {}", debug_line!(self.source)); let identifier = self.consume_namespaced_identifier()?; Ok(h.alloc_variable_expression(|this| VariableExpression { this, @@ -2394,8 +2432,32 @@ impl Lexer<'_> { module_id: None, symbols: Vec::new() })) + } else if self.has_identifier() { + let position = self.source.pos(); + let name = self.consume_ident()?; + self.consume_whitespace(false)?; + let alias = if self.has_string(b"as") { + self.consume_string(b"as")?; + self.consume_whitespace(true)?; + self.consume_ident()? + } else { + name.clone() + }; + + h.alloc_import(|this| Import::Symbols(ImportSymbols{ + this, + position, + module_name: value, + module_id: None, + symbols: vec![AliasedSymbol{ + position, + name, + alias, + definition_id: None + }] + })) } else { - return Err(self.error_at_pos("Expected '*' or '{'")); + return Err(self.error_at_pos("Expected '*', '{' or a symbol name")); } } else { // No explicit alias or subimports, so implicit alias diff --git a/src/protocol/parser/mod.rs b/src/protocol/parser/mod.rs index c2c572555bbb71d197030a636a359dc4f7912b72..76e660b1d3512c2979416ff1a4ae81e52c3463af 100644 --- a/src/protocol/parser/mod.rs +++ b/src/protocol/parser/mod.rs @@ -175,6 +175,8 @@ impl Parser { } } } + + import_index += 1; } // All imports in the AST are now annotated. 
We now use the symbol table diff --git a/src/protocol/parser/type_resolver.rs b/src/protocol/parser/type_resolver.rs index 41107292a061244f03241d946db7858f34de9fa4..955da8b3756a28ff7ade6220b696e3c65de5b784 100644 --- a/src/protocol/parser/type_resolver.rs +++ b/src/protocol/parser/type_resolver.rs @@ -27,6 +27,7 @@ /// type checking. /// /// TODO: Needs a thorough rewrite: +/// 0. polymorph_progress is intentionally broken at the moment. /// 1. For polymorphic type inference we need to have an extra datastructure /// for progressing the polymorphic variables and mapping them back to each /// signature type that uses that polymorphic type. The two types of markers @@ -44,17 +45,6 @@ /// 6. Investigate different ways of performing the type-on-type inference, /// maybe there is a better way then flattened trees + markers? -macro_rules! enabled_debug_print { - (false, $name:literal, $format:literal) => {}; - (false, $name:literal, $format:literal, $($args:expr),*) => {}; - (true, $name:literal, $format:literal) => { - println!("[{}] {}", $name, $format) - }; - (true, $name:literal, $format:literal, $($args:expr),*) => { - println!("[{}] {}", $name, format!($format, $($args),*)) - }; -} - macro_rules! 
debug_log { ($format:literal) => { enabled_debug_print!(true, "types", $format); @@ -1989,6 +1979,9 @@ impl TypeResolvingVisitor { Literal::Character(_) => todo!("character literals"), Literal::Struct(data) => { let extra = self.extra_data.get_mut(&upcast_id).unwrap(); + for poly in &extra.poly_vars { + debug_log!(" * Poly: {}", poly.display_name(&ctx.heap)); + } let mut poly_progress = HashSet::new(); debug_assert_eq!(extra.embedded.len(), data.fields.len()); @@ -2015,6 +2008,8 @@ impl TypeResolvingVisitor { } } + debug_log!(" - Field poly progress | {:?}", poly_progress); + // Same for the type of the struct itself let signature_type: *mut _ = &mut extra.returned; let expr_type: *mut _ = self.expr_types.get_mut(&upcast_id).unwrap(); @@ -2028,6 +2023,7 @@ impl TypeResolvingVisitor { unsafe{&*signature_type}.display_name(&ctx.heap), unsafe{&*expr_type}.display_name(&ctx.heap) ); + debug_log!(" - Ret poly progress | {:?}", poly_progress); if progress_expr { // TODO: @cleanup, cannot call utility self.queue_parent thingo @@ -2452,7 +2448,7 @@ impl TypeResolvingVisitor { /// This function returns true if the expression's type has been progressed fn apply_equal2_polyvar_constraint( heap: &Heap, - polymorph_data: &ExtraData, polymorph_progress: &HashSet, + polymorph_data: &ExtraData, _polymorph_progress: &HashSet, signature_type: *mut InferenceType, expr_type: *mut InferenceType ) -> bool { // Safety: all pointers should be distinct @@ -2468,7 +2464,7 @@ impl TypeResolvingVisitor { while let Some((poly_idx, start_idx)) = signature_type.find_body_marker(seek_idx) { let end_idx = InferenceType::find_subtree_end_idx(&signature_type.parts, start_idx); - if polymorph_progress.contains(&poly_idx) { + // if polymorph_progress.contains(&poly_idx) { // Need to match subtrees let polymorph_type = &polymorph_data.poly_vars[poly_idx]; debug_log!(" - DEBUG: Applying {} to '{}' from '{}'", polymorph_type.display_name(heap), InferenceType::partial_display_name(heap, 
&signature_type.parts[start_idx..]), signature_type.display_name(heap)); @@ -2478,7 +2474,7 @@ impl TypeResolvingVisitor { ).expect("no failure when applying polyvar constraints"); modified_sig = modified_sig || modified_at_marker; - } + // } seek_idx = end_idx; } @@ -2838,10 +2834,7 @@ impl TypeResolvingVisitor { // Retrieve relevant data let expr = &ctx.heap[select_id]; - let field = match &expr.field { - Field::Symbolic(field) => field, - _ => unreachable!(), - }; + let field = expr.field.as_symbolic(); let definition_id = field.definition.unwrap(); let definition = ctx.heap[definition_id].as_struct(); @@ -3120,12 +3113,11 @@ impl TypeResolvingVisitor { } } - fn get_poly_var_and_literal_name(ctx: &Ctx, poly_var_idx: usize, expr: &LiteralExpression) -> (String, String) { - let expr = expr.value.as_struct(); - let definition = &ctx.heap[expr.definition.unwrap()]; + fn get_poly_var_and_type_name(ctx: &Ctx, poly_var_idx: usize, definition_id: DefinitionId) -> (String, String) { + let definition = &ctx.heap[definition_id]; match definition { Definition::Enum(_) | Definition::Function(_) | Definition::Component(_) => - unreachable!(), + unreachable!("get_poly_var_and_type_name called on non-struct value"), Definition::Struct(definition) => ( String::from_utf8_lossy(&definition.poly_vars[poly_var_idx].value).to_string(), String::from_utf8_lossy(&definition.identifier.value).to_string() @@ -3147,7 +3139,8 @@ impl TypeResolvingVisitor { ) }, Expression::Literal(expr) => { - let (poly_var, struct_name) = get_poly_var_and_literal_name(ctx, poly_var_idx, expr); + let lit_struct = expr.value.as_struct(); + let (poly_var, struct_name) = get_poly_var_and_type_name(ctx, poly_var_idx, lit_struct.definition.unwrap()); return ParseError2::new_error( &ctx.module.source, expr.position(), &format!( @@ -3156,6 +3149,17 @@ impl TypeResolvingVisitor { ) ) }, + Expression::Select(expr) => { + let field = expr.field.as_symbolic(); + let (poly_var, struct_name) = 
get_poly_var_and_type_name(ctx, poly_var_idx, field.definition.unwrap()); + return ParseError2::new_error( + &ctx.module.source, expr.position(), + &format!( + "Conflicting type for polymorphic variable '{}' while accessing field '{}' of '{}'", + poly_var, &String::from_utf8_lossy(&field.identifier.value), struct_name + ) + ) + } _ => unreachable!("called construct_poly_arg_error without a call/literal expression") } } @@ -3176,6 +3180,13 @@ impl TypeResolvingVisitor { .collect(), "literal" ), + Expression::Select(expr) => + // Select expression uses the polymorphic variables of the + // struct it is accessing, so get the subject expression. + ( + vec![expr.subject], + "selected field" + ), _ => unreachable!(), }; diff --git a/src/protocol/tests/mod.rs b/src/protocol/tests/mod.rs index 2c42b52e7b51356450f51f2d62d9a505be7aaf47..fe6791df28f790a5c8859202d916a18cd7874d18 100644 --- a/src/protocol/tests/mod.rs +++ b/src/protocol/tests/mod.rs @@ -3,5 +3,6 @@ mod lexer; mod parser_validation; mod parser_inference; mod parser_monomorphs; +mod parser_imports; pub(crate) use utils::{Tester}; \ No newline at end of file diff --git a/src/protocol/tests/parser_imports.rs b/src/protocol/tests/parser_imports.rs new file mode 100644 index 0000000000000000000000000000000000000000..27430f1dcbddde5ff7cc25d5f9682eb4e6d4791e --- /dev/null +++ b/src/protocol/tests/parser_imports.rs @@ -0,0 +1,148 @@ +/// parser_imports.rs +/// +/// Simple import tests + +use super::*; + +#[test] +fn test_module_import() { + Tester::new("single domain name") + .with_source(" + #module external + struct Foo { int field } + ") + .with_source(" + import external; + int caller() { + auto a = external::Foo{ field: 0 }; + return a.field; + } + ") + .compile() + .expect_ok(); + + Tester::new("multi domain name") + .with_source(" + #module external.domain + struct Foo { int field } + ") + .with_source(" + import external.domain; + int caller() { + auto a = domain::Foo{ field: 0 }; + return a.field; + } + ") + 
.compile() + .expect_ok(); + + Tester::new("aliased domain name") + .with_source(" + #module external + struct Foo { int field } + ") + .with_source(" + import external as aliased; + int caller() { + auto a = aliased::Foo{ field: 0 }; + return a.field; + } + ") + .compile() + .expect_ok(); +} + +#[test] +fn test_single_symbol_import() { + Tester::new("specific symbol") + .with_source(" + #module external + struct Foo { int field } + ") + .with_source(" + import external::Foo; + int caller() { + auto a = Foo{ field: 1 }; + auto b = Foo{ field: 2 }; + return a.field + b.field; + }") + .compile() + .expect_ok(); + + Tester::new("specific aliased symbol") + .with_source(" + #module external + struct Foo { int field } + ") + .with_source(" + import external::Foo as Bar; + int caller() { + return Bar{ field: 0 }.field; + } + ") + .compile() + .expect_ok(); + + // TODO: Re-enable once std lib is properly implemented + // Tester::new("import all") + // .with_source(" + // #module external + // struct Foo { int field } + // ") + // .with_source(" + // import external::*; + // int caller() { return Foo{field:0}.field; } + // ") + // .compile() + // .expect_ok(); +} + +#[test] +fn test_multi_symbol_import() { + Tester::new("specific symbols") + .with_source(" + #module external + struct Foo { byte f } + struct Bar { byte b } + ") + .with_source(" + import external::{Foo, Bar}; + byte caller() { + return Foo{f:0}.f + Bar{b:1}.b; + } + ") + .compile() + .expect_ok(); + + Tester::new("aliased symbols") + .with_source(" + #module external + struct Foo { byte in_foo } + struct Bar { byte in_bar } + ") + .with_source(" + import external::{Foo as Bar, Bar as Foo}; + byte caller() { + return Foo{in_bar:0}.in_bar + Bar{in_foo:0}.in_foo; + }") + .compile() + .expect_ok(); + + // TODO: Re-enable once std lib is properly implemented + // Tester::new("import all") + // .with_source(" + // #module external + // struct Foo { byte f }; + // struct Bar { byte b }; + // ") + // .with_source(" 
+ // import external::*; + // byte caller() { + // auto f = Foo{f:0}; + // auto b = Bar{b:0}; + // return f.f + b.b; + // } + // ") + // .compile() + // .expect_ok(); +} \ No newline at end of file diff --git a/src/protocol/tests/parser_inference.rs b/src/protocol/tests/parser_inference.rs index a296a14d08f9f4c4cd211f96293be3d0bfbc980c..88b86aa85187f7807ab4c8546e358121d53e74ef 100644 --- a/src/protocol/tests/parser_inference.rs +++ b/src/protocol/tests/parser_inference.rs @@ -67,6 +67,63 @@ fn test_integer_inference() { }); } +#[test] +fn test_binary_expr_inference() { + Tester::new_single_source_expect_ok( + "compatible types", + "int call() { + byte b0 = 0; + byte b1 = 1; + short s0 = 0; + short s1 = 1; + int i0 = 0; + int i1 = 1; + long l0 = 0; + long l1 = 1; + auto b = b0 + b1; + auto s = s0 + s1; + auto i = i0 + i1; + auto l = l0 + l1; + return i; + }" + ).for_function("call", |f| { f + .for_expression_by_source( + "b0 + b1", "+", + |e| { e.assert_concrete_type("byte"); } + ) + .for_expression_by_source( + "s0 + s1", "+", + |e| { e.assert_concrete_type("short"); } + ) + .for_expression_by_source( + "i0 + i1", "+", + |e| { e.assert_concrete_type("int"); } + ) + .for_expression_by_source( + "l0 + l1", "+", + |e| { e.assert_concrete_type("long"); } + ); + }); + + Tester::new_single_source_expect_err( + "incompatible types", + "int call() { + byte b = 0; + long l = 1; + auto r = b + l; + return 0; + }" + ).error(|e| { e + .assert_ctx_has(0, "b + l") + .assert_msg_has(0, "cannot apply") + .assert_occurs_at(0, "+") + .assert_msg_has(1, "has type 'byte'") + .assert_msg_has(2, "has type 'long'"); + }); +} + + + #[test] fn test_struct_inference() { Tester::new_single_source_expect_ok( @@ -155,4 +212,87 @@ fn test_struct_inference() { .assert_concrete_type("Node>"); }); }); +} + +#[test] +fn test_failed_polymorph_inference() { + Tester::new_single_source_expect_err( + "function call inference mismatch", + " + int poly(T a, T b) { return 0; } + int call() { + byte 
first_arg = 5; + long second_arg = 2; + return poly(first_arg, second_arg); + } + " + ).error(|e| { e + .assert_num(3) + .assert_ctx_has(0, "poly(first_arg, second_arg)") + .assert_occurs_at(0, "poly") + .assert_msg_has(0, "Conflicting type for polymorphic variable 'T'") + .assert_occurs_at(1, "second_arg") + .assert_msg_has(1, "inferred it to 'long'") + .assert_occurs_at(2, "first_arg") + .assert_msg_has(2, "inferred it to 'byte'"); + }); + + Tester::new_single_source_expect_err( + "struct literal inference mismatch", + " + struct Pair{ T first, T second } + int call() { + byte first_arg = 5; + long second_arg = 2; + auto pair = Pair{ first: first_arg, second: second_arg }; + return 3; + } + " + ).error(|e| { e + .assert_num(3) + .assert_ctx_has(0, "Pair{ first: first_arg, second: second_arg }") + .assert_occurs_at(0, "Pair{") + .assert_msg_has(0, "Conflicting type for polymorphic variable 'T'") + .assert_occurs_at(1, "second_arg") + .assert_msg_has(1, "inferred it to 'long'") + .assert_occurs_at(2, "first_arg") + .assert_msg_has(2, "inferred it to 'byte'"); + }); + + Tester::new_single_source_expect_err( + "field access inference mismatch", + " + struct Holder{ Shazam a } + int call() { + byte to_hold = 0; + auto holder = Holder{ a: to_hold }; + return holder.a; + } + " + ).error(|e| { e + .assert_num(3) + .assert_ctx_has(0, "holder.a") + .assert_occurs_at(0, ".") + .assert_msg_has(0, "Conflicting type for polymorphic variable 'Shazam'") + .assert_msg_has(1, "inferred it to 'byte'") + .assert_msg_has(2, "inferred it to 'int'"); + }); + + // TODO: Needs better error messages anyway, but this failed before + Tester::new_single_source_expect_err( + "by nested field access", + " + struct Node{ T1 l, T2 r } + Node construct(T1 l, T2 r) { return Node{ l: l, r: r }; } + int fix_poly(Node a) { return 0; } + int test() { + byte assigned = 0; + long another = 1; + auto thing = construct(assigned, construct(another, 1)); + fix_poly(thing.r); + thing.r.r = assigned; + return 
0; + } + ", + ); } \ No newline at end of file diff --git a/src/protocol/tests/parser_monomorphs.rs b/src/protocol/tests/parser_monomorphs.rs index e3b889c6dc00bd544cb5a0ce6619b210364a81a7..9ba45b85150fd67569da3d58a9138403f0f99023 100644 --- a/src/protocol/tests/parser_monomorphs.rs +++ b/src/protocol/tests/parser_monomorphs.rs @@ -34,5 +34,8 @@ fn test_struct_monomorphs() { .assert_has_monomorph("long") .assert_has_monomorph("Number") .assert_num_monomorphs(5); + }).for_function("instantiator", |f| { f + .for_variable("a", |v| {v.assert_concrete_type("Number");} ) + .for_variable("e", |v| {v.assert_concrete_type("Number>");} ); }); } \ No newline at end of file diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index 209c50695e626572d8575659c5804a7db511f15a..4c15ce3e30162774258dad5a945057904ac1862b 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -398,6 +398,60 @@ impl<'a> FunctionTester<'a> { self } + /// Finds a specific expression within a function. There are two matchers: + /// one outer matcher (to find a rough indication of the expression) and an + /// inner matcher to find the exact expression. + /// + /// The reason being that, for example, a function's body might be littered + /// with addition symbols, so we first match on "some_var + some_other_var", + /// and then match exactly on "+". + pub(crate) fn for_expression_by_source(self, outer_match: &str, inner_match: &str, f: F) -> Self { + // Seek the expression in the source code + assert!(outer_match.contains(inner_match), "improper testing code"); + + let module = seek_def_in_modules( + &self.ctx.heap, &self.ctx.modules, self.def.this.upcast() + ).unwrap(); + + // Find the first occurrence of the expression after the definition of + // the function, we'll check that it is included in the body later. 
+ let mut outer_match_idx = self.def.position.offset; + while outer_match_idx < module.source.input.len() { + if module.source.input[outer_match_idx..].starts_with(outer_match.as_bytes()) { + break; + } + outer_match_idx += 1 + } + + assert!( + outer_match_idx < module.source.input.len(), + "[{}] Failed to find '{}' within the source that contains {}", + self.ctx.test_name, outer_match, self.assert_postfix() + ); + let inner_match_idx = outer_match_idx + outer_match.find(inner_match).unwrap(); + + // Use the inner match index to find the expression + let expr_id = seek_expr_in_stmt( + &self.ctx.heap, self.def.body, + &|expr| expr.position().offset == inner_match_idx + ); + assert!( + expr_id.is_some(), + "[{}] Failed to find '{}' within the source that contains {} \ + (note: expression was found, but not within the specified function", + self.ctx.test_name, outer_match, self.assert_postfix() + ); + let expr_id = expr_id.unwrap(); + + // We have the expression, call the testing function + let tester = ExpressionTester::new( + self.ctx, self.def.this.upcast(), &self.ctx.heap[expr_id] + ); + f(tester); + + self + } + fn assert_postfix(&self) -> String { format!( "Function{{ name: {} }}", @@ -406,7 +460,6 @@ impl<'a> FunctionTester<'a> { } } - pub(crate) struct VariableTester<'a> { ctx: TestCtx<'a>, definition_id: DefinitionId, @@ -456,6 +509,42 @@ impl<'a> VariableTester<'a> { } } +pub(crate) struct ExpressionTester<'a> { + ctx: TestCtx<'a>, + definition_id: DefinitionId, // of the enclosing function/component + expr: &'a Expression +} + +impl<'a> ExpressionTester<'a> { + fn new( + ctx: TestCtx<'a>, definition_id: DefinitionId, expr: &'a Expression + ) -> Self { + Self{ ctx, definition_id, expr } + } + + pub(crate) fn assert_concrete_type(self, expected: &str) -> Self { + let mut serialized = String::new(); + serialize_concrete_type( + &mut serialized, self.ctx.heap, self.definition_id, + self.expr.get_type() + ); + + assert_eq!( + expected, &serialized, + "[{}] 
Expected concrete type '{}', but got '{}' for {}", + self.ctx.test_name, expected, &serialized, self.assert_postfix() + ); + self + } + + fn assert_postfix(&self) -> String { + format!( + "Expression{{ debug: {:?} }}", + self.expr + ) + } +} + //------------------------------------------------------------------------------ // Interface for failed compilation //------------------------------------------------------------------------------ @@ -602,9 +691,10 @@ fn serialize_concrete_type(buffer: &mut String, heap: &Heap, def: DefinitionId, // Retrieve polymorphic variables, if present (since we're dealing with a // concrete type we only expect procedure types) let poly_vars = match &heap[def] { - Definition::Function(func) => &func.poly_vars, - Definition::Component(comp) => &comp.poly_vars, - _ => unreachable!("Error in testing utility: did not expect non-procedure type for concrete type serialization"), + Definition::Function(definition) => &definition.poly_vars, + Definition::Component(definition) => &definition.poly_vars, + Definition::Struct(definition) => &definition.poly_vars, + _ => unreachable!("Error in testing utility: unexpected type for concrete type serialization"), }; fn serialize_recursive( @@ -666,6 +756,19 @@ fn serialize_concrete_type(buffer: &mut String, heap: &Heap, def: DefinitionId, serialize_recursive(buffer, heap, poly_vars, concrete, 0); } +fn seek_def_in_modules<'a>(heap: &Heap, modules: &'a [LexedModule], def_id: DefinitionId) -> Option<&'a LexedModule> { + for module in modules { + let root = &heap.protocol_descriptions[module.root_id]; + for definition in &root.definitions { + if *definition == def_id { + return Some(module) + } + } + } + + None +} + fn seek_stmt bool>(heap: &Heap, start: StatementId, f: &F) -> Option { let stmt = &heap[start]; if f(stmt) { return Some(start); }