diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index d352dabae9db5c6a11858ad0c241bcecf69d87c2..009e3f85c8eb904b8aae1a22cc1301fdb1b9329d 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -1071,6 +1071,7 @@ pub enum Literal { Character(LiteralCharacter), Integer(LiteralInteger), Struct(LiteralStruct), + Enum(LiteralEnum), } impl Literal { @@ -1089,6 +1090,14 @@ impl Literal { unreachable!("Attempted to obtain {:?} as Literal::Struct", self) } } + + pub(crate) fn as_enum(&self) -> &LiteralEnum { + if let Literal::Enum(literal) = self { + literal + } else { + unreachable!("Attempted to obtain {:?} as Literal::Enum", self) + } + } } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] @@ -1106,7 +1115,7 @@ pub struct LiteralStruct { pub(crate) identifier: NamespacedIdentifier, pub(crate) fields: Vec, // Phase 2: linker - pub(crate) poly_args2: Vec, // taken from identifier + pub(crate) poly_args2: Vec, // taken from identifier once linked to a definition pub(crate) definition: Option } @@ -1115,9 +1124,9 @@ pub struct LiteralEnum { // Phase 1: parser pub(crate) identifier: NamespacedIdentifier, // Phase 2: linker - pub(crate) poly_args2: Vec, + pub(crate) poly_args2: Vec, // taken from identifier once linked to a definition pub(crate) definition: Option, - pub(crate) variant_idx: usize, + pub(crate) variant_idx: usize, // index into the enum's variants as stored in the type table } #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] diff --git a/src/protocol/ast_printer.rs b/src/protocol/ast_printer.rs index 316094c00f9b488d9e42b09103a96b0c42f64e63..5a0c23a1af62f4b4e996f4d1a5648c11a9a96186 100644 --- a/src/protocol/ast_printer.rs +++ b/src/protocol/ast_printer.rs @@ -276,15 +276,30 @@ impl ASTWriter { match &heap[def_id] { Definition::Struct(_) => todo!("implement Definition::Struct"), - Definition::Enum(_) => todo!("implement Definition::Enum"), + Definition::Enum(def) => { + self.kv(indent).with_id(PREFIX_ENUM_ID, def.this.0.index)
.with_s_key("DefinitionEnum"); + + self.kv(indent2).with_s_key("Name").with_ascii_val(&def.identifier.value); + for poly_var_id in &def.poly_vars { + self.kv(indent3).with_s_key("PolyVar").with_ascii_val(&poly_var_id.value); + } + + self.kv(indent2).with_s_key("Variants"); + for variant in &def.variants { + self.kv(indent3).with_s_key("Variant"); + self.kv(indent4).with_s_key("Name") + .with_ascii_val(&variant.identifier.value); + // TODO: Attached value + } + }, Definition::Function(def) => { self.kv(indent).with_id(PREFIX_FUNCTION_ID, def.this.0.index) .with_s_key("DefinitionFunction"); self.kv(indent2).with_s_key("Name").with_ascii_val(&def.identifier.value); for poly_var_id in &def.poly_vars { - self.kv(indent3).with_s_key("PolyVar"); - self.kv(indent4).with_s_key("Name").with_ascii_val(&poly_var_id.value); + self.kv(indent3).with_s_key("PolyVar").with_ascii_val(&poly_var_id.value); } self.kv(indent2).with_s_key("ReturnParserType").with_custom_val(|s| write_parser_type(s, heap, &heap[def.return_type])); @@ -306,8 +321,7 @@ impl ASTWriter { self.kv(indent2).with_s_key("PolymorphicVariables"); for poly_var_id in &def.poly_vars { - self.kv(indent3).with_s_key("PolyVar"); - self.kv(indent4).with_s_key("Name").with_ascii_val(&poly_var_id.value); + self.kv(indent3).with_s_key("PolyVar").with_ascii_val(&poly_var_id.value); } self.kv(indent2).with_s_key("Parameters"); @@ -645,6 +659,19 @@ impl ASTWriter { self.kv(indent4).with_s_key("ParserType"); self.write_expr(heap, field.value, indent4 + 1); } + }, + Literal::Enum(data) => { + val.with_s_val("Enum"); + let indent4 = indent3 + 1; + + // Polymorphic arguments + if !data.poly_args2.is_empty() { + self.kv(indent3).with_s_key("PolymorphicArguments"); + for poly_arg in &data.poly_args2 { + self.kv(indent4).with_s_key("Argument") + .with_custom_val(|v| write_parser_type(v, heap, &heap[*poly_arg])); + } + } } } diff --git a/src/protocol/eval.rs b/src/protocol/eval.rs index
0d37e152b5c05da853162089be83102d9ce49477..a337ef11a732ee6b50a149b64e7d6ea2c77f31ac 100644 --- a/src/protocol/eval.rs +++ b/src/protocol/eval.rs @@ -88,6 +88,7 @@ impl Value { } Literal::Character(_data) => unimplemented!(), Literal::Struct(_data) => unimplemented!(), + Literal::Enum(_data) => unimplemented!(), } } fn set(&mut self, index: &Value, value: &Value) -> Option { diff --git a/src/protocol/lexer.rs b/src/protocol/lexer.rs index f31625a0588a6bd1b3a7e2641d7dee81d138a572..bb7f61a98143ad974a9866dada82acfe546b8400 100644 --- a/src/protocol/lexer.rs +++ b/src/protocol/lexer.rs @@ -514,7 +514,9 @@ impl Lexer<'_> { Ok(ident) } - fn consume_namespaced_identifier_spilled(&mut self) -> Result<(), ParseError> { + // Consumes a spilled namespaced identifier and returns the number of + // `::`-separated parts it consists of (a plain identifier yields 1). + fn consume_namespaced_identifier_spilled(&mut self) -> Result { if self.has_reserved() { return Err(self.error_at_pos("Encountered reserved keyword")); } @@ -534,6 +536,7 @@ impl Lexer<'_> { } let mut backup_pos = self.source.pos(); + let mut num_namespaces = 1; consume_part_spilled(self, &mut backup_pos)?; self.consume_whitespace(false)?; while self.has_string(b"::") { @@ -541,10 +544,11 @@ impl Lexer<'_> { self.consume_whitespace(false)?; consume_part_spilled(self, &mut backup_pos)?; self.consume_whitespace(false)?; + num_namespaces += 1; } self.source.seek(backup_pos); - Ok(()) + Ok(num_namespaces) } // Types and type annotations @@ -1436,6 +1440,9 @@ impl Lexer<'_> { if self.has_call_expression() { return Ok(self.consume_call_expression(h)?.upcast()); } + if self.has_enum_literal() { + return Ok(self.consume_enum_literal(h)?.upcast()); + } Ok(self.consume_variable_expression(h)?.upcast()) } fn consume_array_expression(&mut self, h: &mut Heap) -> Result { @@ -1513,7 +1520,37 @@ impl Lexer<'_> { concrete_type: ConcreteType::default(), })) } - + fn has_enum_literal(&mut self) -> bool { + // An enum literal is always: + //
maybe_a_namespace::EnumName::Variant + // So it may for now be distinguished from other literals/variables by + // first checking for struct literals and call expressions, then for + // enum literals, finally for variable expressions. It is different + // from a variable expression in that its identifier _always_ consists + // of multiple `::`-separated parts. + let backup_pos = self.source.pos(); + let result = match self.consume_namespaced_identifier_spilled() { + Ok(num_namespaces) => num_namespaces > 1, + Err(_) => false, + }; + self.source.seek(backup_pos); + result + } + fn consume_enum_literal(&mut self, h: &mut Heap) -> Result { + let identifier = self.consume_namespaced_identifier(h)?; + Ok(h.alloc_literal_expression(|this| LiteralExpression{ + this, + position: identifier.position, + value: Literal::Enum(LiteralEnum{ + identifier, + poly_args2: Vec::new(), + definition: None, + variant_idx: 0, + }), + parent: ExpressionParent::None, + concrete_type: ConcreteType::default(), + })) + } fn has_struct_literal(&mut self) -> bool { // A struct literal is written as: // namespace::StructName{ field: expr } @@ -2234,8 +2271,12 @@ impl Lexer<'_> { lexer.consume_string(b")")?; EnumVariantValue::Type(embedded_type) }, + Some(b'}') => { + // End of enum + EnumVariantValue::None + } _ => { - return Err(lexer.error_at_pos("Expected ',', '=', or '('")); + return Err(lexer.error_at_pos("Expected ',', '=', '}' or '('")); } };
diff --git a/src/protocol/parser/type_resolver.rs b/src/protocol/parser/type_resolver.rs index a964b54db36f6d55fc500238aee0b9caf9e73061..b381df2e8353835bc246adce4e239aecac4914a4 100644 --- a/src/protocol/parser/type_resolver.rs +++ b/src/protocol/parser/type_resolver.rs @@ -245,8 +245,8 @@ impl InferenceType { if cfg!(debug_assertions) { debug_assert!(!parts.is_empty()); if !has_body_marker { - debug_assert!(parts.iter().all(|v| { - if let InferenceTypePart::MarkerBody(_) = v { false } else { true } + debug_assert!(!parts.iter().any(|v| { + matches!(v, InferenceTypePart::MarkerBody(_)) })); } if is_done { @@ -1224,6 +1224,13 @@ impl Visitor2 for TypeResolvingVisitor { for expr_id in expr_ids { self.visit_expr(ctx, expr_id)?; } + }, + Literal::Enum(_) => { + // Enumerations do not carry any subexpressions, but may still + // have a user-defined polymorphic marker variable. For this + // reason we may still have to apply inference to this + // polymorphic variable + self.insert_initial_enum_polymorph_data(ctx, id); } } @@ -1408,8 +1415,11 @@ impl TypeResolvingVisitor { } }, Expression::Literal(lit_expr) => { - let lit_struct = lit_expr.value.as_struct(); - let definition_id = lit_struct.definition.as_ref().unwrap(); + let definition_id = match &lit_expr.value { + Literal::Struct(literal) => literal.definition.as_ref().unwrap(), + Literal::Enum(literal) => literal.definition.as_ref().unwrap(), + _ => unreachable!("post-inference monomorph for non-struct, non-enum literal") + }; if !ctx.types.has_monomorph(definition_id, &monomorph_types) { ctx.types.add_monomorph(definition_id, monomorph_types); } @@ -2062,8 +2072,44 @@ impl TypeResolvingVisitor { let signature_type: *mut _ = &mut extra.returned; let expr_type: *mut _ = self.expr_types.get_mut(&upcast_id).unwrap(); - let progress_expr = Self::apply_equal2_polyvar_constraint(&ctx.heap, - extra, &poly_progress, signature_type, expr_type + let
progress_expr = Self::apply_equal2_polyvar_constraint( + &ctx.heap, extra, &poly_progress, signature_type, expr_type + ); + + progress_expr + }, + Literal::Enum(_) => { + let extra = self.extra_data.get_mut(&upcast_id).unwrap(); + for poly in &extra.poly_vars { + debug_log!(" * Poly: {}", poly.display_name(&ctx.heap)); + } + let mut poly_progress = HashSet::new(); + + debug_log!(" * During (inferring types from return type)"); + + let signature_type: *mut _ = &mut extra.returned; + let expr_type: *mut _ = self.expr_types.get_mut(&upcast_id).unwrap(); + let (_, progress_expr) = Self::apply_equal2_signature_constraint( + ctx, upcast_id, None, extra, &mut poly_progress, + signature_type, 0, expr_type, 0 + )?; + + debug_log!( + " - Ret type | sig: {}, expr: {}", + unsafe{&*signature_type}.display_name(&ctx.heap), + unsafe{&*expr_type}.display_name(&ctx.heap) + ); + + if progress_expr { + // TODO: @cleanup + if let Some(parent_id) = ctx.heap[upcast_id].parent_expr_id() { + self.expr_queued.insert(parent_id); + } + } + + debug_log!(" * During (reinferring from progress polyvars):"); + let progress_expr = Self::apply_equal2_polyvar_constraint( + &ctx.heap, extra, &poly_progress, signature_type, expr_type ); progress_expr @@ -2788,7 +2834,7 @@ impl TypeResolvingVisitor { }; // Note: programmer is capable of specifying fields in a struct literal - // in a different order than on the definition. We take the programmer- + // in a different order than on the definition. We take the literal- // specified order to be leading. let mut embedded_types = Vec::with_capacity(definition.fields.len()); for lit_field in literal.fields.iter() { @@ -2825,6 +2871,46 @@ impl TypeResolvingVisitor { }); } + /// Inserts the extra polymorphic data struct for enum literal expressions. Enum literals have no embedded expressions, so only the enum's (possibly inferred) polymorphic arguments and the resulting enum type are recorded.
+ fn insert_initial_enum_polymorph_data( + &mut self, ctx: &Ctx, lit_id: LiteralExpressionId + ) { + use InferenceTypePart as ITP; + let literal = ctx.heap[lit_id].value.as_enum(); + + // Handle polymorphic arguments to the enum + let mut poly_vars = Vec::with_capacity(literal.poly_args2.len()); + let mut total_num_poly_parts = 0; + for poly_arg_type_id in literal.poly_args2.clone() { // TODO: @performance + let inference_type = self.determine_inference_type_from_parser_type( + ctx, poly_arg_type_id, true + ); + total_num_poly_parts += inference_type.parts.len(); + poly_vars.push(inference_type); + } + + // Handle enum type itself + let parts_reserved = 1 + poly_vars.len() + total_num_poly_parts; + let mut parts = Vec::with_capacity(parts_reserved); + parts.push(ITP::Instance(literal.definition.unwrap(), poly_vars.len())); + let mut enum_type_done = true; + for (poly_var_idx, poly_var) in poly_vars.iter().enumerate() { + if !poly_var.is_done { enum_type_done = false; } + + parts.push(ITP::MarkerBody(poly_var_idx)); + parts.extend(poly_var.parts.iter().cloned()); + } + + debug_assert_eq!(parts.len(), parts_reserved); + let enum_type = InferenceType::new(true, enum_type_done, parts); + + self.extra_data.insert(lit_id.upcast(), ExtraData{ + poly_vars, + embedded: Vec::new(), + returned: enum_type, + }); + } + /// Inserts the extra polymorphic data struct. Assumes that the select /// expression's referenced (definition_id, field_idx) has been resolved.
fn insert_initial_select_polymorph_data( diff --git a/src/protocol/parser/type_table.rs b/src/protocol/parser/type_table.rs index d61f25c0d6159963e1cd11de99c21b5b310267d2..86cd109ebc04b44a443f725934a3f9b527a4ffc4 100644 --- a/src/protocol/parser/type_table.rs +++ b/src/protocol/parser/type_table.rs @@ -168,6 +168,13 @@ impl DefinedTypeVariant { _ => unreachable!("Cannot convert {} to struct variant", self.type_class()) } } + + pub(crate) fn as_enum(&self) -> &EnumType { + match self { + DefinedTypeVariant::Enum(v) => v, + _ => unreachable!("Cannot convert {} to enum variant", self.type_class()) + } + } } /// `EnumType` is the classical C/C++ enum type. It has various variants with @@ -176,14 +183,14 @@ impl DefinedTypeVariant { /// value multiple times, we assume the user is an expert and we consider both /// variants to be equal to one another. pub struct EnumType { - variants: Vec, - representation: PrimitiveType, + pub(crate) variants: Vec, + pub(crate) representation: PrimitiveType, } // TODO: Also support maximum u64 value pub struct EnumVariant { - identifier: Identifier, - value: i64, + pub(crate) identifier: Identifier, + pub(crate) value: i64, } /// `UnionType` is the algebraic datatype (or sum type, or discriminated union). diff --git a/src/protocol/parser/utils.rs b/src/protocol/parser/utils.rs index 8621e7adcc374abaa948e58dd96cc7cc65c1c649..265daac2f257ca0481ff4b47e3bd354d30fb7654 100644 --- a/src/protocol/parser/utils.rs +++ b/src/protocol/parser/utils.rs @@ -16,6 +16,8 @@ pub(crate) enum FindTypeResult<'t, 'i> { } // TODO: @cleanup Find other uses of this pattern +// TODO: Hindsight is 20/20: this belongs in the visitor_linker, not in a +// separate file.
impl<'t, 'i> FindTypeResult<'t, 'i> { /// Utility function to transform the `FindTypeResult` into a `Result` where /// `Ok` contains the resolved type, and `Err` contains a `ParseError` which diff --git a/src/protocol/parser/visitor_linker.rs b/src/protocol/parser/visitor_linker.rs index a28d42f6c9b26a4e15417b86dd3526c1070ad76c..aa7b9ee5f5df6edf27635f665f93cf9da68dd0ec 100644 --- a/src/protocol/parser/visitor_linker.rs +++ b/src/protocol/parser/visitor_linker.rs @@ -699,6 +699,7 @@ impl Visitor2 for ValidityAndLinkerVisitor { debug_assert!(!self.performing_breadth_pass); const FIELD_NOT_FOUND_SENTINEL: usize = usize::max_value(); + const VARIANT_NOT_FOUND_SENTINEL: usize = FIELD_NOT_FOUND_SENTINEL; let constant_expr = &mut ctx.heap[id]; let old_expr_parent = self.expr_parent; @@ -790,6 +791,41 @@ impl Visitor2 for ValidityAndLinkerVisitor { } self.expression_buffer.truncate(old_num_exprs); + }, + Literal::Enum(literal) => { + // Retrieve and set type of enumeration + let (definition, ident_iter) = self.find_symbol_of_type_variant( + &ctx.module.source, ctx.module.root_id, &ctx.symbols, &ctx.types, + &literal.identifier, TypeClass::Enum + )?; + literal.definition = Some(definition.ast_definition); + + // Make sure the variant exists + let (variant_ident, _) = ident_iter.prev().unwrap(); + let enum_definition = definition.definition.as_enum(); + literal.variant_idx = VARIANT_NOT_FOUND_SENTINEL; + + for (variant_idx, variant) in enum_definition.variants.iter().enumerate() { + if variant.identifier.value == variant_ident { + literal.variant_idx = variant_idx; + break; + } + } + + if literal.variant_idx == VARIANT_NOT_FOUND_SENTINEL { + let variant = String::from_utf8_lossy(variant_ident).to_string(); + let literal = ctx.heap[id].value.as_enum(); + let enum_definition = ctx.heap[definition.ast_definition].as_enum(); + return Err(ParseError::new_error( + &ctx.module.source, literal.identifier.position, + &format!( + "The variant '{}' does
not exist on the enum '{}'", + &variant, &String::from_utf8_lossy(&enum_definition.identifier.value) + ) + )) + } + + self.visit_literal_poly_args(ctx, id)?; } } @@ -1371,10 +1407,11 @@ impl ValidityAndLinkerVisitor { /// a definition of a particular type. // Note: root_id, symbols and types passed in explicitly to prevent // borrowing errors - fn find_symbol_of_type<'a>( - &self, source: &InputSource, root_id: RootId, symbols: &SymbolTable, types: &'a TypeTable, - identifier: &NamespacedIdentifier, expected_type_class: TypeClass - ) -> Result<&'a DefinedType, ParseError> { + fn find_symbol_of_type<'t>( + &self, source: &InputSource, root_id: RootId, symbols: &SymbolTable, + types: &'t TypeTable, identifier: &NamespacedIdentifier, + expected_type_class: TypeClass + ) -> Result<&'t DefinedType, ParseError> { // Find symbol associated with identifier let (find_result, _) = find_type_definition(symbols, types, root_id, identifier) .as_parse_error(source)?; @@ -1393,6 +1430,90 @@ impl ValidityAndLinkerVisitor { Ok(find_result) } + /// Finds a particular enum/union using a namespaced identifier. We allow + /// for one more element to exist in the namespaced identifier that + /// supposedly resolves to the enum/union variant.
+ fn find_symbol_of_type_variant<'t, 'i>( + &self, source: &InputSource, root_id: RootId, symbols: &SymbolTable, + types: &'t TypeTable, identifier: &'i NamespacedIdentifier, + expected_type_class: TypeClass + ) -> Result<(&'t DefinedType, NamespacedIdentifierIter<'i>), ParseError> { + debug_assert!(expected_type_class == TypeClass::Enum || expected_type_class == TypeClass::Union); + let (symbol, mut ident_iter) = symbols.resolve_namespaced_identifier(root_id, identifier); + + if symbol.is_none() { + return Err(ParseError::new_error( + source, identifier.position, + "Could not resolve this identifier to a symbol" + )); + } + + let symbol = symbol.unwrap(); + match symbol.symbol { + Symbol::Namespace(_) => { + return Err(ParseError::new_error( + source, identifier.position, + "This identifier was resolved to a namespace instead of a type" + )) + }, + Symbol::Definition((_, definition_id)) => { + let definition = types.get_base_definition(&definition_id); + debug_assert!(definition.is_some()); + let definition = definition.unwrap(); + + let definition_type_class = definition.definition.type_class(); + if expected_type_class != definition_type_class { + return Err(ParseError::new_error( + source, identifier.position, + &format!( + "Expected to find a {}, this symbols points to a {}", + expected_type_class, definition_type_class + ) + )); + } + + // Make sure we have a variant (that doesn't contain any + // polymorphic args) + let next_part = ident_iter.next(); + if next_part.is_none() { + return Err(ParseError::new_error( + source, identifier.position, + &format!( + "This identifier points to the type '{}', did you mean to instantiate a variant?", + String::from_utf8_lossy(&identifier.value) + ) + )); + } + let (_, next_polyargs) = next_part.unwrap(); + + // Now we make sure that there aren't even more identifiers, and + // make sure that the variant does not contain any polymorphic + // arguments.
In that case we can simplify the later visit of + // the (optional) polymorphic args of the enum. + let returned_section = ident_iter.returned_section(); + if ident_iter.num_remaining() != 0 { + return Err(ParseError::new_error( + source, identifier.position, + &format!( + "Too many identifiers, did you mean to write '{}'", + &String::from_utf8_lossy(returned_section) + ) + )); + } + + if next_polyargs.is_some() { + return Err(ParseError::new_error( + source, identifier.position, + "Encountered polymorphic args to an enum variant. These can only be specified for the enum type." + )) + } + + // Resolved to the expected type followed by exactly one variant part without polymorphic args + Ok((definition, ident_iter)) + } + } + } + /// This function will check if the provided while statement ID has a block /// statement that is one of our current parents. fn has_parent_while_scope(&self, ctx: &Ctx, id: WhileStatementId) -> bool { @@ -1566,18 +1687,26 @@ impl ValidityAndLinkerVisitor { fn visit_literal_poly_args(&mut self, ctx: &mut Ctx, lit_id: LiteralExpressionId) -> VisitorResult { // TODO: @token Revisit when tokenizer is implemented let literal_expr = &mut ctx.heap[lit_id]; - if let Literal::Struct(literal) = &mut literal_expr.value { - literal.poly_args2.extend(&literal.identifier.poly_args); + match &mut literal_expr.value { + Literal::Struct(literal) => { + literal.poly_args2.extend(&literal.identifier.poly_args); + }, + Literal::Enum(literal) => { + literal.poly_args2.extend(&literal.identifier.poly_args); + }, + _ => { + debug_assert!(false, "called visit_literal_poly_args on a non-polymorphic literal"); + unreachable!(); + } } let literal_expr = &ctx.heap[lit_id]; let literal_pos = literal_expr.position; - let num_poly_args_to_infer = match &literal_expr.value { + let (num_poly_args_to_infer, poly_args_to_visit) = match &literal_expr.value { Literal::Null | Literal::False | Literal::True | Literal::Character(_) | Literal::Integer(_) => { // Not really an error, but a programmer error as we're likely // doing work twice -
debug_assert!(false, "called visit_literal_poly_args on a non-polymorphic literal"); unreachable!(); }, Literal::Struct(literal) => { @@ -1590,19 +1719,29 @@ impl ValidityAndLinkerVisitor { defined_type, maybe_poly_args, literal.identifier.position ).as_parse_error(&ctx.heap, &ctx.module.source)?; - // Visit all specified parser types - let old_num_types = self.parser_type_buffer.len(); - self.parser_type_buffer.extend(&literal.poly_args2); - while self.parser_type_buffer.len() > old_num_types { - let parser_type_id = self.parser_type_buffer.pop().unwrap(); - self.visit_parser_type(ctx, parser_type_id)?; - } - self.parser_type_buffer.truncate(old_num_types); + (num_to_infer, &literal.poly_args2) + }, + Literal::Enum(literal) => { + let defined_type = ctx.types.get_base_definition(literal.definition.as_ref().unwrap()) + .unwrap(); + let maybe_poly_args = literal.identifier.get_poly_args(); + let num_to_infer = match_polymorphic_args_to_vars( + defined_type, maybe_poly_args, literal.identifier.position + ).as_parse_error(&ctx.heap, &ctx.module.source)?; - num_to_infer + (num_to_infer, &literal.poly_args2) } }; + // Visit all specified parser types in the polymorphic arguments + let old_num_types = self.parser_type_buffer.len(); + self.parser_type_buffer.extend(poly_args_to_visit); + while self.parser_type_buffer.len() > old_num_types { + let parser_type_id = self.parser_type_buffer.pop().unwrap(); + self.visit_parser_type(ctx, parser_type_id)?; + } + self.parser_type_buffer.truncate(old_num_types); + if num_poly_args_to_infer != 0 { for _ in 0..num_poly_args_to_infer { self.parser_type_buffer.push(ctx.heap.alloc_parser_type(|this| ParserType{ @@ -1610,13 +1749,14 @@ impl ValidityAndLinkerVisitor { })); } - let literal = match &mut ctx.heap[lit_id].value { - Literal::Struct(literal) => literal, + let poly_args = match &mut ctx.heap[lit_id].value { + Literal::Struct(literal) => &mut literal.poly_args2, +
Literal::Enum(literal) => &mut literal.poly_args2, _ => unreachable!(), }; - literal.poly_args2.reserve(num_poly_args_to_infer); + poly_args.reserve(num_poly_args_to_infer); for _ in 0..num_poly_args_to_infer { - literal.poly_args2.push(self.parser_type_buffer.pop().unwrap()); + poly_args.push(self.parser_type_buffer.pop().unwrap()); } } diff --git a/src/protocol/tests/parser_inference.rs b/src/protocol/tests/parser_inference.rs index 88b86aa85187f7807ab4c8546e358121d53e74ef..a3eecc0b65a4ddd0e140b9942a34c4967c498e20 100644 --- a/src/protocol/tests/parser_inference.rs +++ b/src/protocol/tests/parser_inference.rs @@ -214,6 +214,98 @@ fn test_struct_inference() { }); } +#[test] +fn test_enum_inference() { + Tester::new_single_source_expect_ok( + "no polymorphic vars", + " + enum Choice { A, B } + int test_instances() { + auto foo = Choice::A; + auto bar = Choice::B; + return 0; + } + " + ).for_function("test_instances", |f| { f + .for_variable("foo", |v| { v + .assert_parser_type("auto") + .assert_concrete_type("Choice"); + }) + .for_variable("bar", |v| { v + .assert_parser_type("auto") + .assert_concrete_type("Choice"); + }); + }); + + Tester::new_single_source_expect_ok( + "one polymorphic var", + " + enum Choice{ + A, + B, + } + int fix_as_byte(Choice arg) { return 0; } + int fix_as_int(Choice arg) { return 0; } + int test_instances() { + auto choice_byte = Choice::A; + auto choice_int1 = Choice::B; + Choice choice_int2 = Choice::B; + fix_as_byte(choice_byte); + fix_as_int(choice_int1); + return fix_as_int(choice_int2); + } + " + ).for_function("test_instances", |f| { f + .for_variable("choice_byte", |v| { v + .assert_parser_type("auto") + .assert_concrete_type("Choice"); + }) + .for_variable("choice_int1", |v| { v + .assert_parser_type("auto") + .assert_concrete_type("Choice"); + }) + .for_variable("choice_int2", |v| { v + .assert_parser_type("Choice") + .assert_concrete_type("Choice"); + }); + }); + + Tester::new_single_source_expect_ok( + "two polymorphic vars",
+ " + enum Choice{ A, B, } + int fix_t1(Choice arg) { return 0; } + int fix_t2(Choice arg) { return 0; } + int test_instances() { + Choice choice1 = Choice::A; + Choice choice2 = Choice::A; + Choice choice3 = Choice::B; + auto choice4 = Choice::B; + fix_t1(choice1); fix_t1(choice2); fix_t1(choice3); fix_t1(choice4); + fix_t2(choice1); fix_t2(choice2); fix_t2(choice3); fix_t2(choice4); + return 0; + } + " + ).for_function("test_instances", |f| { f + .for_variable("choice1", |v| { v + .assert_parser_type("Choice") + .assert_concrete_type("Choice"); + }) + .for_variable("choice2", |v| { v + .assert_parser_type("Choice") + .assert_concrete_type("Choice"); + }) + .for_variable("choice3", |v| { v + .assert_parser_type("Choice") + .assert_concrete_type("Choice"); + }) + .for_variable("choice4", |v| { v + .assert_parser_type("auto") + .assert_concrete_type("Choice"); + }); + }); +} + #[test] fn test_failed_polymorph_inference() { Tester::new_single_source_expect_err( diff --git a/src/protocol/tests/parser_monomorphs.rs b/src/protocol/tests/parser_monomorphs.rs index 9ba45b85150fd67569da3d58a9138403f0f99023..9d5ccafa4f6b43c49bdc4448543c06a218ab4d35 100644 --- a/src/protocol/tests/parser_monomorphs.rs +++ b/src/protocol/tests/parser_monomorphs.rs @@ -38,4 +38,36 @@ fn test_struct_monomorphs() { .for_variable("a", |v| {v.assert_concrete_type("Number");} ) .for_variable("e", |v| {v.assert_concrete_type("Number>");} ); }); +} + +#[test] +fn test_enum_monomorphs() { + Tester::new_single_source_expect_ok( + "no polymorph", + " + enum Answer{ Yes, No } + int do_it() { auto a = Answer::Yes; return 0; } + " + ).for_enum("Answer", |e| { e + .assert_num_monomorphs(0); + }); + + Tester::new_single_source_expect_ok( + "single polymorph", + " + enum Answer { Yes, No } + int instantiator() { + auto a = Answer::Yes; + auto b = Answer::No; + auto c = Answer::Yes; + auto d = Answer>>::No; + return 0; + } + " + ).for_enum("Answer", |e| { e + .assert_num_monomorphs(3) +
.assert_has_monomorph("byte") + .assert_has_monomorph("int") + .assert_has_monomorph("Answer>"); + }); } \ No newline at end of file diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index f5aaff9b5281820d684fff34c06aab35d63ddea9..5b1c07dd45fe71c3196c467b4d7bacd49b7c2e97 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -161,6 +161,31 @@ impl AstOkTester { unreachable!() } + pub(crate) fn for_enum(self, name: &str, f: F) -> Self { + let mut found = false; + for definition in self.heap.definitions.iter() { + if let Definition::Enum(definition) = definition { + if String::from_utf8_lossy(&definition.identifier.value) != name { + continue; + } + + // Found enum with the same name + let tester = EnumTester::new(self.ctx(), definition); + f(tester); + found = true; + break; + } + } + + if found { return self } + + assert!( + false, "[{}] Failed to find definition for enum '{}'", + self.test_name, name + ); + unreachable!() + } + pub(crate) fn for_function(self, name: &str, f: F) -> Self { let mut found = false; for definition in self.heap.definitions.iter() { @@ -221,11 +246,10 @@ impl<'a> StructTester<'a> { } pub(crate) fn assert_num_monomorphs(self, num: usize) -> Self { - let type_def = self.ctx.types.get_base_definition(&self.def.this.upcast()).unwrap(); - assert_eq!( - num, type_def.monomorphs.len(), - "[{}] Expected {} monomorphs, but found {} for {}", - self.ctx.test_name, num, type_def.monomorphs.len(), self.assert_postfix() + let (is_equal, num_encountered) = has_equal_num_monomorphs(self.ctx, num, self.def.this.upcast()); + assert!( + is_equal, "[{}] Expected {} monomorphs, but got {} for {}", + self.ctx.test_name, num, num_encountered, self.assert_postfix() ); self } @@ -233,35 +257,10 @@ impl<'a> StructTester<'a> { /// Asserts that a monomorph exist, separate polymorphic variable types by
a semicolon. pub(crate) fn assert_has_monomorph(self, serialized_monomorph: &str) -> Self { - let definition_id = self.def.this.upcast(); - let type_def = self.ctx.types.get_base_definition(&definition_id).unwrap(); - - let mut full_buffer = String::new(); - full_buffer.push('['); - for (monomorph_idx, monomorph) in type_def.monomorphs.iter().enumerate() { - let mut buffer = String::new(); - for (element_idx, monomorph_element) in monomorph.iter().enumerate() { - if element_idx != 0 { buffer.push(';'); } - serialize_concrete_type(&mut buffer, self.ctx.heap, definition_id, monomorph_element); - } - - if buffer == serialized_monomorph { - // Found an exact match - return self - } - - if monomorph_idx != 0 { - full_buffer.push_str(", "); - } - full_buffer.push('"'); - full_buffer.push_str(&buffer); - full_buffer.push('"'); - } - full_buffer.push(']'); - + let (has_monomorph, serialized) = has_monomorph(self.ctx, self.def.this.upcast(), serialized_monomorph); assert!( - false, "[{}] Expected to find monomorph {}, but got {} for {}", - self.ctx.test_name, serialized_monomorph, &full_buffer, self.assert_postfix() + has_monomorph, "[{}] Expected to find monomorph {}, but got {} for {}", + self.ctx.test_name, serialized_monomorph, &serialized, self.assert_postfix() ); self } @@ -328,6 +327,57 @@ impl<'a> StructFieldTester<'a> { } } +pub(crate) struct EnumTester<'a> { + ctx: TestCtx<'a>, + def: &'a EnumDefinition, +} + +impl<'a> EnumTester<'a> { + fn new(ctx: TestCtx<'a>, def: &'a EnumDefinition) -> Self { + Self{ ctx, def } + } + + pub(crate) fn assert_num_variants(self, num: usize) -> Self { + assert_eq!( + num, self.def.variants.len(), + "[{}] Expected {} enum variants, but found {} for {}", + self.ctx.test_name, num, self.def.variants.len(), self.assert_postfix() + ); + self + } + + pub(crate) fn assert_num_monomorphs(self, num: usize) -> Self { + let (is_equal, num_encountered) = has_equal_num_monomorphs(self.ctx, num, self.def.this.upcast()); + assert!( + is_equal, "[{}] Expected
{} monomorphs, but got {} for {}", + self.ctx.test_name, num, num_encountered, self.assert_postfix() + ); + self + } + + pub(crate) fn assert_has_monomorph(self, serialized_monomorph: &str) -> Self { + let (has_monomorph, serialized) = has_monomorph(self.ctx, self.def.this.upcast(), serialized_monomorph); + assert!( + has_monomorph, "[{}] Expected to find monomorph {}, but got {} for {}", + self.ctx.test_name, serialized_monomorph, serialized, self.assert_postfix() + ); + self + } + + pub(crate) fn assert_postfix(&self) -> String { + let mut v = String::new(); + v.push_str("Enum{ name: "); + v.push_str(&String::from_utf8_lossy(&self.def.identifier.value)); + v.push_str(", variants: ["); + for (variant_idx, variant) in self.def.variants.iter().enumerate() { + if variant_idx != 0 { v.push_str(", "); } + v.push_str(&String::from_utf8_lossy(&variant.identifier.value)); + } + v.push_str("] }"); + v + } +} + pub(crate) struct FunctionTester<'a> { ctx: TestCtx<'a>, def: &'a Function, @@ -645,6 +695,43 @@ impl<'a> ErrorTester<'a> { // Generic utilities //------------------------------------------------------------------------------ +fn has_equal_num_monomorphs<'a>(ctx: TestCtx<'a>, num: usize, definition_id: DefinitionId) -> (bool, usize) { + let type_def = ctx.types.get_base_definition(&definition_id).unwrap(); + let num_on_type = type_def.monomorphs.len(); + + (num_on_type == num, num_on_type) +} + +fn has_monomorph<'a>(ctx: TestCtx<'a>, definition_id: DefinitionId, serialized_monomorph: &str) -> (bool, String) { + let type_def = ctx.types.get_base_definition(&definition_id).unwrap(); + + let mut full_buffer = String::new(); + let mut has_match = false; + full_buffer.push('['); + for (monomorph_idx, monomorph) in type_def.monomorphs.iter().enumerate() { + let mut buffer = String::new(); + for (element_idx, monomorph_element) in monomorph.iter().enumerate() { + if element_idx != 0 { buffer.push(';'); } + serialize_concrete_type(&mut buffer, ctx.heap, definition_id,
monomorph_element); + } + + if buffer == serialized_monomorph { + // Found an exact match; keep iterating so the full list of monomorphs is still collected for the caller's failure message + has_match = true; + } + + if monomorph_idx != 0 { + full_buffer.push_str(", "); + } + full_buffer.push('"'); + full_buffer.push_str(&buffer); + full_buffer.push('"'); + } + full_buffer.push(']'); + + (has_match, full_buffer) +} + fn serialize_parser_type(buffer: &mut String, heap: &Heap, id: ParserTypeId) { use ParserTypeVariant as PTV; @@ -740,12 +827,14 @@ fn serialize_concrete_type(buffer: &mut String, heap: &Heap, def: DefinitionId, CTP::Instance(definition_id, num_sub) => { let definition_name = heap[*definition_id].identifier(); buffer.push_str(&String::from_utf8_lossy(&definition_name.value)); - buffer.push('<'); - for sub_idx in 0..*num_sub { - if sub_idx != 0 { buffer.push(','); } - idx = serialize_recursive(buffer, heap, poly_vars, concrete, idx + 1); + if *num_sub != 0 { + buffer.push('<'); + for sub_idx in 0..*num_sub { + if sub_idx != 0 { buffer.push(','); } + idx = serialize_recursive(buffer, heap, poly_vars, concrete, idx + 1); + } + buffer.push('>'); } - buffer.push('>'); idx += 1; } }