From 1811fe09856a30864bc60e509fdb948387248a12 2021-03-19 19:31:26 From: MH Date: 2021-03-19 19:31:26 Subject: [PATCH] more progress on function call inference --- diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index e293677ea5b26a0c226fbca3ba5572bcf4810408..2d81d1aabb71209306ea0e326894ffd1af6378e8 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -2271,6 +2271,7 @@ pub struct CallExpression { pub position: InputPosition, pub method: Method, pub arguments: Vec, + pub poly_args: Vec, // Phase 2: linker pub parent: ExpressionParent, } diff --git a/src/protocol/lexer.rs b/src/protocol/lexer.rs index d8b00c04276d90cede9bb56f99a6945cb5db1872..e87532ab007d89541ba5dcc0394d4f9a14a95d89 100644 --- a/src/protocol/lexer.rs +++ b/src/protocol/lexer.rs @@ -347,6 +347,7 @@ impl Lexer<'_> { let mut ns_ident = self.consume_ident()?; let mut num_namespaces = 1; while self.has_string(b"::") { + self.consume_string(b"::"); if num_namespaces >= MAX_NAMESPACES { return Err(self.error_at_pos("Too many namespaces in identifier")); } @@ -362,6 +363,20 @@ impl Lexer<'_> { num_namespaces, }) } + fn consume_namespaced_identifier_spilled(&mut self) -> Result<(), ParseError2> { + // TODO: @performance + if self.has_reserved() { + return Err(self.error_at_pos("Encountered reserved keyword")); + } + + self.consume_ident()?; + while self.has_string(b"::") { + self.consume_string(b"::")?; + self.consume_ident()?; + } + + Ok(()) + } // Types and type annotations @@ -510,65 +525,43 @@ impl Lexer<'_> { Ok(parser_type_id) } - /// Consumes things that look like types. If everything seems to look like - /// a type then `true` will be returned and the input position will be - /// placed after the type. If it doesn't appear to be a type then `false` - /// will be returned. - /// TODO: @cleanup, this is not particularly pretty or robust, methinks - fn maybe_consume_type_spilled(&mut self) -> bool { - // Spilling polymorphic args. 
Don't care about the input position - fn maybe_consume_polymorphic_args(v: &mut Lexer) -> bool { - if v.consume_whitespace(false).is_err() { return false; } - if let Some(b'<') = v.source.next() { - v.source.consume(); - if v.consume_whitespace(false).is_err() { return false; } - loop { - if !maybe_consume_type_inner(v) { return false; } - if v.consume_whitespace(false).is_err() { return false; } - let has_comma = v.source.next() == Some(b','); - if has_comma { - v.source.consume(); - if v.consume_whitespace(false).is_err() { return false; } - } - if let Some(b'>') = v.source.next() { - v.source.consume(); - break; - } else if !has_comma { - return false; - } - } - } - return true; + /// Attempts to consume a type without returning it. If it doesn't encounter + /// a well-formed type, then the input position is left at a "random" + /// position. + fn maybe_consume_type_spilled_without_pos_recovery(&mut self) -> bool { + // Consume type identifier + if self.has_type_keyword() { + self.consume_any_chars(); + } else { + let ident = self.consume_namespaced_identifier(); + if ident.is_err() { return false; } } - // Inner recursive type parser. 
This method simply advances the lexer - // and does not store the backup position in case parsing fails - fn maybe_consume_type_inner(v: &mut Lexer) -> bool { - // Consume type identifier and optional polymorphic args - if v.has_type_keyword() { - v.consume_any_chars() - } else { - let ident = v.consume_namespaced_identifier(); - if ident.is_err() { return false } - } - - if !maybe_consume_polymorphic_args(v) { return false; } - - // Check if wrapped in array - if v.consume_whitespace(false).is_err() { return false } - while let Some(b'[') = v.source.next() { - v.source.consume(); - if v.consume_whitespace(false).is_err() { return false; } - if Some(b']') != v.source.next() { return false; } - v.source.consume(); - } + // Consume any polymorphic arguments that follow the type identifier + if self.consume_whitespace(false).is_err() { return false; } + if !self.maybe_consume_poly_args_spilled_without_pos_recovery() { return false; } - return true; + // Consume any array specifiers. Make sure we always leave the input + // position at the end of the last array specifier if we do find a + // valid type + let mut backup_pos = self.source.pos(); + if self.consume_whitespace(false).is_err() { return false; } + while let Some(b'[') = self.source.next() { + self.source.consume(); + if self.consume_whitespace(false).is_err() { return false; } + if self.source.next() != Some(b']') { return false; } + self.source.consume(); + backup_pos = self.source.pos(); + if self.consume_whitespace(false).is_err() { return false; } } + self.source.seek(backup_pos); + return true; + } + + fn maybe_consume_type_spilled(&mut self) -> bool { let backup_pos = self.source.pos(); - if !maybe_consume_type_inner(self) { - // Not a type + if !self.maybe_consume_type_spilled_without_pos_recovery() { self.source.seek(backup_pos); return false; } @@ -576,6 +569,33 @@ impl Lexer<'_> { return true; } + /// Attempts to consume polymorphic arguments without returning them. 
If it + /// doesn't encounter well-formed polymorphic arguments, then the input + /// position is left at a "random" position. + fn maybe_consume_poly_args_spilled_without_pos_recovery(&mut self) -> bool { + if let Some(b'<') = self.source.next() { + self.source.consume(); + if self.consume_whitespace(false).is_err() { return false; } + loop { + if !self.maybe_consume_type_spilled_without_pos_recovery() { return false; } + if self.consume_whitespace(false).is_err() { return false; } + let has_comma = self.source.next() == Some(b','); + if has_comma { + self.source.consume(); + if self.consume_whitespace(false).is_err() { return false; } + } + if let Some(b'>') = self.source.next() { + self.source.consume(); + break; + } else if !has_comma { + return false; + } + } + } + + return true; + } + /// Consumes polymorphic arguments and its delimiters if specified. The /// input position may be at whitespace. If polyargs are present then the /// whitespace and the args are consumed and the input position will be @@ -1383,28 +1403,31 @@ impl Lexer<'_> { })) } fn has_call_expression(&mut self) -> bool { - /* We prevent ambiguity with variables, by looking ahead - the identifier to see if we can find an opening - parenthesis: this signals a call expression. */ + // We need to prevent ambiguity with various operators (because we may + // be specifying polymorphic variables) and variables. 
if self.has_builtin_keyword() { return true; } + let backup_pos = self.source.pos(); let mut result = false; - match self.consume_identifier_spilled() { - Ok(_) => match self.consume_whitespace(false) { - Ok(_) => { - result = self.has_string(b"("); - } - Err(_) => {} - }, - Err(_) => {} + + if self.consume_namespaced_identifier_spilled().is_ok() && + self.consume_whitespace(false).is_ok() && + self.maybe_consume_poly_args_spilled_without_pos_recovery().is_ok() && + self.consume_whitespace(false).is_ok() && + self.source.next() == Some(b'(') { + // Seems like we have a function call or an enum literal + result = true; } + self.source.seek(backup_pos); return result; } fn consume_call_expression(&mut self, h: &mut Heap) -> Result { let position = self.source.pos(); + + // Consume method identifier let method; if self.has_keyword(b"get") { self.consume_keyword(b"get")?; @@ -1422,11 +1445,18 @@ impl Lexer<'_> { definition: None }) } + + // Consume polymorphic arguments + self.consume_whitespace(false)?; + let poly_args = self.consume_polymorphic_args(h, true)?; + + // Consume arguments to call self.consume_whitespace(false)?; let mut arguments = Vec::new(); self.consume_string(b"(")?; self.consume_whitespace(false)?; if !self.has_string(b")") { + // TODO: allow trailing comma while self.source.next().is_some() { arguments.push(self.consume_expression(h)?); self.consume_whitespace(false)?; @@ -1443,6 +1473,7 @@ impl Lexer<'_> { position, method, arguments, + poly_args, parent: ExpressionParent::None, })) } @@ -1565,9 +1596,9 @@ impl Lexer<'_> { } let backup_pos = self.source.pos(); let mut result = false; - if self.maybe_consume_type_spilled() { + if self.maybe_consume_type_spilled_without_pos_recovery() { // We seem to have a valid type, do we now have an identifier? 
- if self.consume_whitespace(false).is_ok() { + if self.consume_whitespace(true).is_ok() { result = self.has_identifier(); } } diff --git a/src/protocol/parser/type_resolver.rs b/src/protocol/parser/type_resolver.rs index fe7523b7c9cef3a86a6e6d24f73a0b1e2a783771..a930607092b1517cd117dca6d401a1d0bade0ea7 100644 --- a/src/protocol/parser/type_resolver.rs +++ b/src/protocol/parser/type_resolver.rs @@ -11,10 +11,14 @@ use super::visitor::{ Visitor2, VisitorResult }; +use std::collections::hash_map::Entry; const BOOL_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::Bool ]; const NUMBERLIKE_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::NumberLike ]; +const INTEGERLIKE_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::IntegerLike ]; +const ARRAY_TEMPLATE: [InferenceTypePart; 2] = [ InferenceTypePart::Array, InferenceTypePart::Unknown ]; const ARRAYLIKE_TEMPLATE: [InferenceTypePart; 2] = [ InferenceTypePart::ArrayLike, InferenceTypePart::Unknown ]; +const PORTLIKE_TEMPLATE: [InferenceTypePart; 2] = [ InferenceTypePart::PortLike, InferenceTypePart::Unknown ]; /// TODO: @performance Turn into PartialOrd+Ord to simplify checks #[derive(Debug, Clone, Eq, PartialEq)] @@ -30,6 +34,7 @@ pub(crate) enum InferenceTypePart { NumberLike, // any kind of integer/float IntegerLike, // any kind of integer ArrayLike, // array or slice. 
Note that this must have a subtype + PortLike, // input or output port // Special types that cannot be instantiated by the user Void, // For builtin functions that do not return anything // Concrete types without subtypes @@ -59,7 +64,7 @@ impl InferenceTypePart { fn is_concrete(&self) -> bool { use InferenceTypePart as ITP; match self { - ITP::Unknown | ITP::NumberLike | ITP::IntegerLike | ITP::ArrayLike => false, + ITP::Unknown | ITP::NumberLike | ITP::IntegerLike | ITP::ArrayLike | ITP::PortLike => false, _ => true } } @@ -89,6 +94,14 @@ impl InferenceTypePart { } } + fn is_concrete_port(&self) -> bool { + use InferenceTypePart as ITP; + match self { + ITP::Input | ITP::Output => true, + _ => false, + } + } + /// Returns the change in "iteration depth" when traversing this particular /// part. The iteration depth is used to traverse the tree in a linear /// fashion. It is basically `number_of_subtypes - 1` @@ -102,7 +115,7 @@ impl InferenceTypePart { -1 }, ITP::Marker(_) | ITP::ArrayLike | ITP::Array | ITP::Slice | - ITP::Input | ITP::Output => { + ITP::PortLike | ITP::Input | ITP::Output => { // One subtype, so do not modify depth 0 }, @@ -113,6 +126,28 @@ impl InferenceTypePart { } } +impl From for InferenceTypePart { + fn from(v: ConcreteTypeVariant) -> InferenceTypePart { + use ConcreteTypeVariant as CTV; + use InferenceTypePart as ITP; + + match v { + CTV::Message => ITP::Message, + CTV::Bool => ITP::Bool, + CTV::Byte => ITP::Byte, + CTV::Short => ITP::Short, + CTV::Int => ITP::Int, + CTV::Long => ITP::Long, + CTV::String => ITP::String, + CTV::Array => ITP::Array, + CTV::Slice => ITP::Slice, + CTV::Input => ITP::Input, + CTV::Output => ITP::Output, + CTV::Instance(id, num) => ITP::Instance(id, num), + } + } +} + struct InferenceType { has_marker: bool, is_done: bool, @@ -122,7 +157,7 @@ struct InferenceType { impl InferenceType { fn new(has_marker: bool, is_done: bool, parts: Vec) -> Self { if cfg!(debug_assertions) { - debug_assert(!parts.is_empty()); + 
debug_assert!(!parts.is_empty()); if !has_marker { debug_assert!(parts.iter().all(|v| !v.is_marker())); } @@ -135,7 +170,7 @@ impl InferenceType { // TODO: @performance, might all be done inline in the type inference methods fn recompute_is_done(&mut self) { - self.done = self.parts.iter().all(|v| v.is_concrete()); + self.is_done = self.parts.iter().all(|v| v.is_concrete()); } /// Checks if type is, or may be inferred as, a number @@ -181,6 +216,10 @@ impl InferenceType { } } + fn marker_iter(&self) -> InferenceTypeMarkerIter { + InferenceTypeMarkerIter::new(&self.parts) + } + /// Given that the `parts` are a depth-first serialized tree of types, this /// function finds the subtree anchored at a specific node. The returned /// index is exclusive. @@ -231,9 +270,11 @@ impl InferenceType { } // Inference of a somewhat-specified type - if (*to_infer_part == ITP::IntegerLike && template_part.is_concrete_int()) || - (*to_infer_part == ITP::NumberLike && template_part.is_concrete_number())|| - (*to_infer_part == ITP::ArrayLike && template_part.is_concrete_array_or_slice()) + if (*to_infer_part == ITP::IntegerLike && template_part.is_concrete_integer()) || + (*to_infer_part == ITP::NumberLike && template_part.is_concrete_number()) || + (*to_infer_part == ITP::NumberLike && *template_part == ITP::IntegerLike) || + (*to_infer_part == ITP::ArrayLike && template_part.is_concrete_array_or_slice()) || + (*to_infer_part == ITP::PortLike && template_part.is_concrete_port()) { let depth_change = to_infer_part.depth_change(); debug_assert_eq!(depth_change, template_part.depth_change()); @@ -341,11 +382,11 @@ impl InferenceType { while depth > 0 { let to_infer_part = &to_infer.parts[to_infer_idx]; - let template_part = &template.parts[template_idx]; + let template_part = &template[template_idx]; - if part_a == part_b { + if to_infer_part == template_part { depth += to_infer_part.depth_change(); - debug_assert!(depth, template_part.depth_change()); + debug_assert_eq!(depth, 
template_part.depth_change()); to_infer_idx += 1; template_idx += 1; continue; @@ -362,6 +403,24 @@ impl InferenceType { continue; } + // The template might contain partially known types, so check for + // these and allow them + if *template_part == ITP::Unknown { + to_infer_idx = Self::find_subtree_end_idx(&to_infer.parts, to_infer_idx); + template_idx += 1; + continue; + } + + if (*template_part == ITP::NumberLike && (*to_infer_part == ITP::IntegerLike || to_infer_part.is_concrete_number())) || + (*template_part == ITP::IntegerLike && to_infer_part.is_concrete_integer()) || + (*template_part == ITP::ArrayLike && to_infer_part.is_concrete_array_or_slice()) || + (*template_part == ITP::PortLike && (*to_infer_part == ITP::PortLike || to_infer_part.is_concrete_port())) + { + to_infer_idx += 1; + template_idx += 1; + continue; + } + return SingleInferenceResult::Incompatible } @@ -376,48 +435,56 @@ impl InferenceType { /// Returns a human-readable version of the type. Only use for debugging /// or returning errors (since it allocates a string). 
fn display_name(&self, heap: &Heap) -> String { - use InferredPart as IP; + use InferenceTypePart as ITP; fn write_recursive(v: &mut String, t: &InferenceType, h: &Heap, idx: &mut usize) { match &t.parts[*idx] { - IP::Unknown => v.push_str("?"), - IP::Void => v.push_str("void"), - IP::IntegerLike => v.push_str("int?"), - IP::Message => v.push_str("msg"), - IP::Bool => v.push_str("bool"), - IP::Byte => v.push_str("byte"), - IP::Short => v.push_str("short"), - IP::Int => v.push_str("int"), - IP::Long => v.push_str("long"), - IP::String => v.push_str("str"), - IP::ArrayLike => { + ITP::Marker(_) => {}, + ITP::Unknown => v.push_str("?"), + ITP::NumberLike => v.push_str("num?"), + ITP::IntegerLike => v.push_str("int?"), + ITP::ArrayLike => { *idx += 1; write_recursive(v, t, h, idx); v.push_str("[?]"); + }, + ITP::PortLike => { + *idx += 1; + v.push_str("port?<"); + write_recursive(v, t, h, idx); + v.push('>'); } - IP::Array => { + ITP::Void => v.push_str("void"), + ITP::Message => v.push_str("msg"), + ITP::Bool => v.push_str("bool"), + ITP::Byte => v.push_str("byte"), + ITP::Short => v.push_str("short"), + ITP::Int => v.push_str("int"), + ITP::Long => v.push_str("long"), + ITP::String => v.push_str("str"), + ITP::Array => { *idx += 1; write_recursive(v, t, h, idx); v.push_str("[]"); }, - IP::Slice => { + ITP::Slice => { *idx += 1; write_recursive(v, t, h, idx); v.push_str("[..]") }, - IP::Input => { + ITP::Input => { *idx += 1; v.push_str("in<"); write_recursive(v, t, h, idx); v.push('>'); }, - IP::Output => { + ITP::Output => { *idx += 1; v.push_str("out<"); write_recursive(v, t, h, idx); v.push('>'); }, - IP::Instance(definition_id, num_sub) => { + ITP::Instance(definition_id, num_sub) => { let definition = &h[*definition_id]; v.push_str(&String::from_utf8_lossy(&definition.identifier().value)); if *num_sub > 0 { @@ -450,6 +517,12 @@ struct InferenceTypeMarkerIter<'a> { idx: usize, } +impl<'a> InferenceTypeMarkerIter<'a> { + fn new(parts: &'a [InferenceTypePart]) -> 
Self { + Self{ parts, idx: 0 } + } +} + impl<'a> Iterator for InferenceTypeMarkerIter<'a> { type Item = (usize, &'a [InferenceTypePart]); @@ -473,19 +546,6 @@ impl<'a> Iterator for InferenceTypeMarkerIter<'a> { } } -/// Extra data needed to fully resolve polymorphic types. Each argument contains -/// "markers" with an index corresponding to the polymorphic variable. Hence if -/// we advance any of the inference types with markers then we need to compare -/// them against the polymorph type. If the polymorph type is then progressed -/// then we need to apply that to all arguments that contain that polymorphic -/// type. -struct PolymorphInferenceType { - definition: DefinitionId, - poly_vars: Vec, - arguments: Vec, - return_type: InferenceType, -} - #[derive(PartialEq, Eq)] enum DualInferenceResult { Neither, // neither argument is clarified @@ -553,9 +613,19 @@ pub(crate) struct TypeResolvingVisitor { // specify these types until we're stuck or we've fully determined the type. infer_types: HashMap, expr_types: HashMap, + extra_data: HashMap, expr_queued: HashSet, } +// TODO: @rename used for calls and struct literals, maybe union literals? 
+struct ExtraData { + /// Progression of polymorphic variables (if any) + poly_vars: Vec, + /// Progression of types of call arguments or struct members + embedded: Vec, + returned: InferenceType, +} + impl TypeResolvingVisitor { pub(crate) fn new() -> Self { TypeResolvingVisitor{ @@ -565,6 +635,7 @@ impl TypeResolvingVisitor { polyvars: Vec::new(), infer_types: HashMap::new(), expr_types: HashMap::new(), + extra_data: HashMap::new(), expr_queued: HashSet::new(), } } @@ -591,8 +662,8 @@ impl Visitor2 for TypeResolvingVisitor { for param_id in comp_def.parameters.clone() { let param = &ctx.heap[param_id]; - let infer_type = self.determine_inference_type_from_parser_type(ctx, param.parser_type); - debug_assert!(infer_type.done, "expected component arguments to be concrete types"); + let infer_type = self.determine_inference_type_from_parser_type(ctx, param.parser_type, true); + debug_assert!(infer_type.is_done, "expected component arguments to be concrete types"); self.infer_types.insert(param_id.upcast(), infer_type); } @@ -609,8 +680,8 @@ impl Visitor2 for TypeResolvingVisitor { for param_id in func_def.parameters.clone() { let param = &ctx.heap[param_id]; - let infer_type = self.determine_inference_type_from_parser_type(ctx, param.parser_type); - debug_assert!(infer_type.done, "expected function arguments to be concrete types"); + let infer_type = self.determine_inference_type_from_parser_type(ctx, param.parser_type, true); + debug_assert!(infer_type.is_done, "expected function arguments to be concrete types"); self.infer_types.insert(param_id.upcast(), infer_type); } @@ -635,7 +706,7 @@ impl Visitor2 for TypeResolvingVisitor { let memory_stmt = &ctx.heap[id]; let local = &ctx.heap[memory_stmt.variable]; - let infer_type = self.determine_inference_type_from_parser_type(ctx, local.parser_type); + let infer_type = self.determine_inference_type_from_parser_type(ctx, local.parser_type, true); self.infer_types.insert(memory_stmt.variable.upcast(), infer_type); let 
expr_id = memory_stmt.initial; @@ -648,11 +719,11 @@ impl Visitor2 for TypeResolvingVisitor { let channel_stmt = &ctx.heap[id]; let from_local = &ctx.heap[channel_stmt.from]; - let from_infer_type = self.determine_inference_type_from_parser_type(ctx, from_local.parser_type); + let from_infer_type = self.determine_inference_type_from_parser_type(ctx, from_local.parser_type, true); self.infer_types.insert(from_local.this.upcast(), from_infer_type); let to_local = &ctx.heap[channel_stmt.to]; - let to_infer_type = self.determine_inference_type_from_parser_type(ctx, to_local.parser_type); + let to_infer_type = self.determine_inference_type_from_parser_type(ctx, to_local.parser_type, true); self.infer_types.insert(to_local.this.upcast(), to_infer_type); Ok(()) @@ -742,7 +813,7 @@ impl Visitor2 for TypeResolvingVisitor { fn visit_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> VisitorResult { let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id); + self.insert_initial_expr_inference_type(ctx, upcast_id)?; let assign_expr = &ctx.heap[id]; let left_expr_id = assign_expr.left; @@ -756,7 +827,7 @@ impl Visitor2 for TypeResolvingVisitor { fn visit_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> VisitorResult { let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id); + self.insert_initial_expr_inference_type(ctx, upcast_id)?; let conditional_expr = &ctx.heap[id]; let test_expr_id = conditional_expr.test; @@ -773,7 +844,7 @@ impl Visitor2 for TypeResolvingVisitor { fn visit_binary_expr(&mut self, ctx: &mut Ctx, id: BinaryExpressionId) -> VisitorResult { let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id); + self.insert_initial_expr_inference_type(ctx, upcast_id)?; let binary_expr = &ctx.heap[id]; let lhs_expr_id = binary_expr.left; @@ -787,7 +858,7 @@ impl Visitor2 for TypeResolvingVisitor { fn visit_unary_expr(&mut self, ctx: 
&mut Ctx, id: UnaryExpressionId) -> VisitorResult { let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id); + self.insert_initial_expr_inference_type(ctx, upcast_id)?; let unary_expr = &ctx.heap[id]; let arg_expr_id = unary_expr.expression; @@ -799,10 +870,11 @@ impl Visitor2 for TypeResolvingVisitor { fn visit_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> VisitorResult { let upcast_id = id.upcast(); - self.insert_initial_expr_inference_type(ctx, upcast_id); + self.insert_initial_expr_inference_type(ctx, upcast_id)?; + self.insert_initial_call_polymorph_data(ctx, id); - let call_expr = &ctx.heap[id]; // TODO: @performance + let call_expr = &ctx.heap[id]; for arg_expr_id in call_expr.arguments.clone() { self.visit_expr(ctx, arg_expr_id)?; } @@ -811,24 +883,6 @@ impl Visitor2 for TypeResolvingVisitor { } } -// TODO: @cleanup Decide to use this where appropriate or to make templates for -// everything -enum TypeClass { - Numeric, // int and float - Integer, // only ints - Boolean, // only boolean -} - -impl std::fmt::Display for TypeClass { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "{}", match self { - TypeClass::Numeric => "numeric", - TypeClass::Integer => "integer", - TypeClass::Boolean => "boolean", - }) - } -} - macro_rules! 
debug_assert_expr_ids_unique_and_known { // Base case for a single expression ID ($resolver:ident, $id:ident) => { @@ -862,31 +916,26 @@ impl TypeResolvingVisitor { use AssignmentOperator as AO; // TODO: Assignable check - let (type_class, arg1_expr_id, arg2_expr_id) = { - let expr = &ctx.heap[id]; - let type_class = match expr.operation { - AO::Set => - None, - AO::Multiplied | AO::Divided | AO::Added | AO::Subtracted => - Some(TypeClass::Numeric), - AO::Remained | AO::ShiftedLeft | AO::ShiftedRight | - AO::BitwiseAnded | AO::BitwiseXored | AO::BitwiseOred => - Some(TypeClass::Integer), - }; - - (type_class, expr.left, expr.right) + let upcast_id = id.upcast(); + let expr = &ctx.heap[id]; + let arg1_expr_id = expr.left; + let arg2_expr_id = expr.right; + + let progress_base = match expr.operation { + AO::Set => + false, + AO::Multiplied | AO::Divided | AO::Added | AO::Subtracted => + self.apply_forced_constraint(ctx, upcast_id, &NUMBERLIKE_TEMPLATE)?, + AO::Remained | AO::ShiftedLeft | AO::ShiftedRight | + AO::BitwiseAnded | AO::BitwiseXored | AO::BitwiseOred => + self.apply_forced_constraint(ctx, upcast_id, &INTEGERLIKE_TEMPLATE)?, }; - let upcast_id = id.upcast(); let (progress_expr, progress_arg1, progress_arg2) = self.apply_equal3_constraint( - ctx, upcast_id, arg1_expr_id, arg2_expr_id + ctx, upcast_id, arg1_expr_id, arg2_expr_id, 0 )?; - if let Some(type_class) = type_class { - self.expr_type_is_of_type_class(ctx, id.upcast(), type_class)? 
- } - - if progress_expr { self.queue_expr_parent(ctx, upcast_id); } + if progress_base || progress_expr { self.queue_expr_parent(ctx, upcast_id); } if progress_arg1 { self.queue_expr(arg1_expr_id); } if progress_arg2 { self.queue_expr(arg2_expr_id); } @@ -901,7 +950,7 @@ impl TypeResolvingVisitor { let arg2_expr_id = expr.false_expression; let (progress_expr, progress_arg1, progress_arg2) = self.apply_equal3_constraint( - ctx, upcast_id, arg1_expr_id, arg2_expr_id + ctx, upcast_id, arg1_expr_id, arg2_expr_id, 0 )?; if progress_expr { self.queue_expr_parent(ctx, upcast_id); } @@ -923,35 +972,49 @@ impl TypeResolvingVisitor { let (progress_expr, progress_arg1, progress_arg2) = match expr.operation { BO::Concatenate => { - // Arguments may be arrays/slices with the same subtype. Output - // is always an array with that subtype - (false, false, false) + // Arguments may be arrays/slices, output is always an array + let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &ARRAY_TEMPLATE)?; + let progress_arg1 = self.apply_forced_constraint(ctx, arg1_id, &ARRAYLIKE_TEMPLATE)?; + let progress_arg2 = self.apply_forced_constraint(ctx, arg2_id, &ARRAYLIKE_TEMPLATE)?; + + // If they're all arraylike, then we want the subtype to match + let (subtype_expr, subtype_arg1, subtype_arg2) = + self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id, 1)?; + + (progress_expr || subtype_expr, progress_arg1 || subtype_arg1, progress_arg2 || subtype_arg2) }, BO::LogicalOr | BO::LogicalAnd => { // Forced boolean on all - let progress_expr = self.apply_forced_constraint(ctx, upcast_id, BOOL_TEMPLATE.as_slice())?; + let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?; let progress_arg1 = self.apply_forced_constraint(ctx, arg1_id, &BOOL_TEMPLATE)?; let progress_arg2 = self.apply_forced_constraint(ctx, arg2_id, &BOOL_TEMPLATE)?; (progress_expr, progress_arg1, progress_arg2) }, BO::BitwiseOr | BO::BitwiseXor | BO::BitwiseAnd | BO::Remainder | 
BO::ShiftLeft | BO::ShiftRight => { - let result = self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id)?; - self.expr_type_is_of_type_class(ctx, upcast_id, TypeClass::Integer)?; - result + // All equal of integer type + let progress_base = self.apply_forced_constraint(ctx, upcast_id, &INTEGERLIKE_TEMPLATE)?; + let (progress_expr, progress_arg1, progress_arg2) = + self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id, 0)?; + + (progress_base || progress_expr, progress_base || progress_arg1, progress_base || progress_arg2) }, BO::Equality | BO::Inequality | BO::LessThan | BO::GreaterThan | BO::LessThanEqual | BO::GreaterThanEqual => { // Equal2 on args, forced boolean output let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?; - let (progress_arg1, progress_arg2) = self.apply_equal2_constraint(ctx, upcast_id, arg1_id, arg2_id)?; - self.expr_type_is_of_type_class(ctx, arg1_id, TypeClass::Numeric)?; + let progress_arg_base = self.apply_forced_constraint(ctx, arg1_id, &NUMBERLIKE_TEMPLATE)?; + let (progress_arg1, progress_arg2) = + self.apply_equal2_constraint(ctx, upcast_id, arg1_id, 0, arg2_id, 0)?; - (progress_expr, progress_arg1, progress_arg2) + (progress_expr, progress_arg_base || progress_arg1, progress_arg_base || progress_arg2) }, BO::Add | BO::Subtract | BO::Multiply | BO::Divide => { - let result = self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id)?; - self.expr_type_is_of_type_class(ctx, upcast_id, TypeClass::Numeric)?; - result + // All equal of number type + let progress_base = self.apply_forced_constraint(ctx, upcast_id, &NUMBERLIKE_TEMPLATE)?; + let (progress_expr, progress_arg1, progress_arg2) = + self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id, 0)?; + + (progress_base || progress_expr, progress_base || progress_arg1, progress_base || progress_arg2) }, }; @@ -972,15 +1035,19 @@ impl TypeResolvingVisitor { let (progress_expr, progress_arg) = match expr.operation { UO::Positive | 
UO::Negative => { // Equal types of numeric class - let progress = self.apply_equal2_constraint(ctx, upcast_id, upcast_id, arg_id)?; - self.expr_type_is_of_type_class(ctx, upcast_id, TypeClass::Numeric)?; - progress + let progress_base = self.apply_forced_constraint(ctx, upcast_id, &NUMBERLIKE_TEMPLATE)?; + let (progress_expr, progress_arg) = + self.apply_equal2_constraint(ctx, upcast_id, upcast_id, 0, arg_id, 0)?; + + (progress_base || progress_expr, progress_base || progress_arg) }, UO::BitwiseNot | UO::PreIncrement | UO::PreDecrement | UO::PostIncrement | UO::PostDecrement => { // Equal types of integer class - let progress = self.apply_equal2_constraint(ctx, upcast_id, upcast_id, arg_id)?; - self.expr_type_is_of_type_class(ctx, upcast_id, TypeClass::Integer)?; - progress + let progress_base = self.apply_forced_constraint(ctx, upcast_id, &INTEGERLIKE_TEMPLATE)?; + let (progress_expr, progress_arg) = + self.apply_equal2_constraint(ctx, upcast_id, upcast_id, 0, arg_id, 0)?; + + (progress_base || progress_expr, progress_base || progress_arg) }, UO::LogicalNot => { // Both booleans @@ -997,25 +1064,86 @@ impl TypeResolvingVisitor { } fn progress_indexing_expr(&mut self, ctx: &mut Ctx, id: IndexingExpressionId) -> Result<(), ParseError2> { - // TODO: Indexable check let upcast_id = id.upcast(); let expr = &ctx.heap[id]; let subject_id = expr.subject; let index_id = expr.index; - let progress_subject = self.apply_forced_constraint(ctx, subject_id, &ARRAYLIKE_TEMPLATE)?; + // Make sure subject is arraylike and index is integerlike + let progress_subject_base = self.apply_forced_constraint(ctx, subject_id, &ARRAYLIKE_TEMPLATE)?; let progress_index = self.apply_forced_constraint(ctx, index_id, &INTEGERLIKE_TEMPLATE)?; - // TODO: Finish this + // Make sure if output is of T then subject is Array + let (progress_expr, progress_subject) = + self.apply_equal2_constraint(ctx, upcast_id, upcast_id, 0, subject_id, 1)?; + + if progress_expr { self.queue_expr_parent(ctx, 
upcast_id); } + if progress_subject_base || progress_subject { self.queue_expr(subject_id); } + if progress_index { self.queue_expr(index_id); } + Ok(()) } - fn progress_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> Result<(), ParseError2> { - let - upcast_id = id.upcast(); + fn progress_slicing_expr(&mut self, ctx: &mut Ctx, id: SlicingExpressionId) -> Result<(), ParseError2> { + let upcast_id = id.upcast(); let expr = &ctx.heap[id]; + let subject_id = expr.subject; + let from_id = expr.from_index; + let to_id = expr.to_index; + + // Make sure subject is arraylike and indices are of equal integerlike + let progress_subject_base = self.apply_forced_constraint(ctx, subject_id, &ARRAYLIKE_TEMPLATE)?; + let progress_idx_base = self.apply_forced_constraint(ctx, from_id, &INTEGERLIKE_TEMPLATE)?; + let (progress_from, progress_to) = self.apply_equal2_constraint(ctx, upcast_id, from_id, 0, to_id, 0)?; + + // Make sure if output is of T then subject is Array + let (progress_expr, progress_subject) = + self.apply_equal2_constraint(ctx, upcast_id, upcast_id, 0, subject_id, 1)?; + + if progress_expr { self.queue_expr_parent(ctx, upcast_id); } + if progress_subject_base || progress_subject { self.queue_expr(subject_id); } + if progress_idx_base || progress_from { self.queue_expr(from_id); } + if progress_idx_base || progress_to { self.queue_expr(to_id); } + + Ok(()) + } + fn progress_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> Result<(), ParseError2> { + let upcast_id = id.upcast(); + let expr = &ctx.heap[id]; + let extra = self.extra_data.get_mut(&upcast_id).unwrap(); + + // Check if we can make progress using the arguments and/or return types + // while keeping track of the polyvars we've extended + let mut poly_progress = HashSet::new(); + debug_assert_eq!(extra.embedded.len(), expr.arguments.len()); + for (arg_idx, arg_id) in expr.arguments.clone().into_iter().enumerate() { + let extra_type: *mut _ = &mut extra.embedded[arg_idx]; + let 
(progress_expr, progress_extra) = self.apply_arglike_equal2_constraint(ctx, arg_id, extra_type)?; + + if progress_expr { self.queue_expr(arg_id); } + if progress_extra { + unsafe { + // Try to advance each polymorphic variable + debug_assert!((*extra_type).has_marker); + let mut marker_iter = unsafe { (*extra_type).marker_iter() }; + for (marker_idx, section) in marker_iter { + let poly_type: *mut _ = &mut extra.poly_vars[marker_idx]; + match InferenceType::infer_subtree_for_single_type(&mut *poly_type, 0, section, 0) { + SingleInferenceResult::Unmodified => {}, + SingleInferenceResult::Modified => { + poly_progress.insert(marker_idx); + }, + SingleInferenceResult::Incompatible => { + todo!("Decent error message, and how?"); + } + } + } + } + } + } + Ok(()) } fn queue_expr_parent(&mut self, ctx: &Ctx, expr_id: ExpressionId) { @@ -1038,9 +1166,9 @@ impl TypeResolvingVisitor { debug_assert_expr_ids_unique_and_known!(self, expr_id); let expr_type = self.expr_types.get_mut(&expr_id).unwrap(); match InferenceType::infer_subtree_for_single_type(expr_type, 0, template, 0) { - InferenceTemplateResult::Modified => Ok(true), - InferenceTemplateResult::Unmodified => Ok(false), - InferenceTemplateResult::Incompatible => Err( + SingleInferenceResult::Modified => Ok(true), + SingleInferenceResult::Unmodified => Ok(false), + SingleInferenceResult::Incompatible => Err( self.construct_template_type_error(ctx, expr_id, template) ) } @@ -1051,13 +1179,18 @@ impl TypeResolvingVisitor { /// is successful then the composition of all types are made equal. /// The "parent" `expr_id` is provided to construct errors. 
fn apply_equal2_constraint( - &mut self, ctx: &mut Ctx, expr_id: ExpressionId, arg1_id: ExpressionId, arg2_id: ExpressionId + &mut self, ctx: &Ctx, expr_id: ExpressionId, + arg1_id: ExpressionId, arg1_start_idx: usize, + arg2_id: ExpressionId, arg2_start_idx: usize ) -> Result<(bool, bool), ParseError2> { debug_assert_expr_ids_unique_and_known!(self, arg1_id, arg2_id); let arg1_type: *mut _ = self.expr_types.get_mut(&arg1_id).unwrap(); let arg2_type: *mut _ = self.expr_types.get_mut(&arg2_id).unwrap(); - let infer_res = unsafe{ InferenceType::infer_subtrees_for_both_types(arg1_type, 0, arg2_type, 0) }; + let infer_res = unsafe{ InferenceType::infer_subtrees_for_both_types( + arg1_type, arg1_start_idx, + arg2_type, arg2_start_idx + ) }; if infer_res == DualInferenceResult::Incompatible { return Err(self.construct_arg_type_error(ctx, expr_id, arg1_id, arg2_id)); } @@ -1065,13 +1198,39 @@ impl TypeResolvingVisitor { Ok((infer_res.modified_lhs(), infer_res.modified_rhs())) } + // TODO: @cleanup Bit of a hack, but borrowing rules are really annoying here. Maybe not pack + // `ExtraData` together, but keep as a HashMap with very specific keys? e.g. 
ReturnType(ExprId) + fn apply_arglike_equal2_constraint( + &mut self, ctx: &Ctx, expr_id: ExpressionId, + direct_type: *mut InferenceType + ) -> Result<(bool, bool), ParseError2> { + let expr_type: *mut _ = self.expr_types.get_mut(&expr_id).unwrap(); + let infer_res = unsafe{ + InferenceType::infer_subtrees_for_both_types(expr_type, 0, direct_type, 0) + }; + if infer_res == DualInferenceResult::Incompatible { + let expr_type = unsafe{ &*expr_type }; + let direct_type = unsafe{ &*direct_type }; + return Err(ParseError2::new_error( + &ctx.module.source, ctx.heap[expr_id].position(), + &format!( + "Expected type '{}' but got '{}'", + direct_type.display_name(ctx.heap), expr_type.display_name(ctx.heap) + ) + )); + } + + Ok((infer_res.modified_lhs(), infer_res.modified_rhs())) + } + /// Applies a type constraint that expects all three provided types to be /// equal. In case we can make progress in inferring the types then we /// attempt to do so. If the call is successful then the composition of all /// types is made equal. 
fn apply_equal3_constraint( - &mut self, ctx: &mut Ctx, expr_id: ExpressionId, - arg1_id: ExpressionId, arg2_id: ExpressionId + &mut self, ctx: &Ctx, expr_id: ExpressionId, + arg1_id: ExpressionId, arg2_id: ExpressionId, + start_idx: usize ) -> Result<(bool, bool, bool), ParseError2> { // Safety: all expression IDs are always distinct, and we do not modify // the container @@ -1080,12 +1239,15 @@ impl TypeResolvingVisitor { let arg1_type: *mut _ = self.expr_types.get_mut(&arg1_id).unwrap(); let arg2_type: *mut _ = self.expr_types.get_mut(&arg2_id).unwrap(); - let expr_res = unsafe{ InferenceType::infer_subtrees_for_both_types(expr_type, 0, arg1_type, 0) }; + let expr_res = unsafe{ + InferenceType::infer_subtrees_for_both_types(expr_type, start_idx, arg1_type, start_idx) + }; if expr_res == DualInferenceResult::Incompatible { return Err(self.construct_expr_type_error(ctx, expr_id, arg1_id)); } - let args_res = unsafe{ InferenceType::infer_subtrees_for_both_types(arg1_type, 0, arg2_type, 0) }; + let args_res = unsafe{ + InferenceType::infer_subtrees_for_both_types(arg1_type, start_idx, arg2_type, start_idx) }; if args_res == DualInferenceResult::Incompatible { return Err(self.construct_arg_type_error(ctx, expr_id, arg1_id, arg2_id)); } @@ -1098,8 +1260,8 @@ impl TypeResolvingVisitor { if args_res.modified_lhs() { unsafe { - (*expr_type).parts.clear(); - (*expr_type).parts.extend((*arg2_type).parts.iter()); + (*expr_type).parts.drain(start_idx..); + (*expr_type).parts.extend_from_slice(&((*arg2_type).parts[start_idx..])); } progress_expr = true; progress_arg1 = true; @@ -1108,27 +1270,6 @@ impl TypeResolvingVisitor { Ok((progress_expr, progress_arg1, progress_arg2)) } - /// Applies a typeclass constraint: checks if the type is of a particular - /// class or not - fn expr_type_is_of_type_class( - &mut self, ctx: &mut Ctx, expr_id: ExpressionId, type_class: TypeClass - ) -> Result<(), ParseError2> { - debug_assert_expr_ids_unique_and_known!(self, expr_id); - let 
expr_type = self.expr_types.get(&expr_id).unwrap(); - - let is_ok = match type_class { - TypeClass::Numeric => expr_type.might_be_numeric(), - TypeClass::Integer => expr_type.might_be_integer(), - TypeClass::Boolean => expr_type.might_be_boolean(), - }; - - if is_ok { - Ok(()) - } else { - Err(self.construct_type_class_error(ctx, expr_id, type_class)) - } - } - /// Determines the `InferenceType` for the expression based on the /// expression parent. Note that if the parent is another expression, we do /// not take special action, instead we let parent expressions fix the type @@ -1137,11 +1278,9 @@ impl TypeResolvingVisitor { /// anything. fn insert_initial_expr_inference_type( &mut self, ctx: &mut Ctx, expr_id: ExpressionId - ) { - // TODO: @cleanup Concept of "parent expression" can be removed, the - // type inferer/checker can set this upon the initial pass + ) -> Result<(), ParseError2> { use ExpressionParent as EP; - if self.expr_types.contains_key(&expr_id) { return; } + use InferenceTypePart as ITP; let expr = &ctx.heap[expr_id]; let inference_type = match expr.parent() { @@ -1150,15 +1289,15 @@ impl TypeResolvingVisitor { unreachable!(), EP::Memory(_) | EP::ExpressionStmt(_) | EP::Expression(_, _) => // Determined during type inference - InferenceType::new(false, false, vec![InferredPart::Unknown]), + InferenceType::new(false, false, vec![ITP::Unknown]), EP::If(_) | EP::While(_) | EP::Assert(_) => // Must be a boolean - InferenceType::new(false, true, vec![InferredPart::Bool]), + InferenceType::new(false, true, vec![ITP::Bool]), EP::Return(_) => // Must match the return type of the function if let DefinitionType::Function(func_id) = self.definition_type { let return_parser_type_id = ctx.heap[func_id].return_type; - self.determine_inference_type_from_parser_type(ctx, return_parser_type_id) + self.determine_inference_type_from_parser_type(ctx, return_parser_type_id, true) } else { // Cannot happen: definition always set upon body traversal // and "return" 
calls in components are illegal. @@ -1167,60 +1306,177 @@ impl TypeResolvingVisitor { EP::New(_) => // Must be a component call, which we assign a "Void" return // type - InferenceType::new(false, true, vec![InferredPart::Void]), + InferenceType::new(false, true, vec![ITP::Void]), EP::Put(_, 0) => // TODO: Change put to be a builtin function // port of "put" call - InferenceType::new(false, false, vec![InferredPart::Output, InferredPart::Unknown]), + InferenceType::new(false, false, vec![ITP::Output, ITP::Unknown]), EP::Put(_, 1) => // TODO: Change put to be a builtin function // message of "put" call - InferenceType::new(false, true, vec![InferredPart::Message]), + InferenceType::new(false, true, vec![ITP::Message]), EP::Put(_, _) => unreachable!() }; - self.expr_types.insert(expr_id, inference_type); + match self.expr_types.entry(expr_id) { + Entry::Vacant(vacant) => { + vacant.insert(inference_type); + }, + Entry::Occupied(mut preexisting) => { + // We already have an entry, this happens if our parent fixed + // our type (e.g. we're used in a conditional expression's test) + // but we have a different type. + // TODO: Is this ever called? Seems like it can't + debug_assert!(false, "I am actually called, my ID is {}", expr_id.index); + let old_type = preexisting.get_mut(); + if let SingleInferenceResult::Incompatible = InferenceType::infer_subtree_for_single_type( + old_type, 0, &inference_type.parts, 0 + ) { + return Err(self.construct_expr_type_error(ctx, expr_id, expr_id)) + } + } + } + + Ok(()) + } + + fn insert_initial_call_polymorph_data( + &mut self, ctx: &mut Ctx, call_id: CallExpressionId + ) { + use InferenceTypePart as ITP; + + // Note: the polymorph variables may be partially specified and may + // contain references to the wrapping definition's (i.e. the proctype + // we are currently visiting) polymorphic arguments. 
+ // + // The arguments of the call may refer to polymorphic variables in the + // definition of the function we're calling, not of the wrapping + // definition. We insert markers in these inferred types to be able to + // map them back and forth to the polymorphic arguments of the function + // we are calling. + let call = &ctx.heap[call_id]; + debug_assert!(!call.poly_args.is_empty()); + + // Handle the polymorphic variables themselves + let mut poly_vars = Vec::with_capacity(call.poly_args.len()); + for poly_arg_type_id in call.poly_args.clone() { // TODO: @performance + poly_vars.push(self.determine_inference_type_from_parser_type(ctx, poly_arg_type_id, true)); + } + + // Handle the arguments + // TODO: @cleanup: Maybe factor this out for reuse in the validator/linker, should also + // make the code slightly more robust. + let (embedded_types, return_type) = match &call.method { + Method::Create => { + // Not polymorphic + unreachable!("insert initial polymorph data for builtin 'create()' call") + }, + Method::Fires => { + // bool fires(PortLike arg) + ( + vec![InferenceType::new(true, false, vec![ITP::PortLike, ITP::Marker(0), ITP::Unknown])], + InferenceType::new(false, true, vec![ITP::Bool]) + ) + }, + Method::Get => { + // T get(input arg) + ( + vec![InferenceType::new(true, false, vec![ITP::Input, ITP::Marker(0), ITP::Unknown])], + InferenceType::new(true, false, vec![ITP::Marker(0), ITP::Unknown]) + ) + }, + Method::Symbolic(symbolic) => { + let definition = &ctx.heap[symbolic.definition.unwrap()]; + + match definition { + Definition::Component(definition) => { + let mut parameter_types = Vec::with_capacity(definition.parameters.len()); + for param_id in definition.parameters.clone() { + let param = &ctx.heap[param_id]; + let param_parser_type_id = param.parser_type; + parameter_types.push(self.determine_inference_type_from_parser_type(ctx, param_parser_type_id, false)); + } + + (parameter_types, InferenceType::new(false, true, 
vec![InferenceTypePart::Unknown])) + }, + Definition::Function(definition) => { + let mut parameter_types = Vec::with_capacity(definition.parameters.len()); + for param_id in definition.parameters.clone() { + let param = &ctx.heap[param_id]; + let param_parser_type_id = param.parser_type; + parameter_types.push(self.determine_inference_type_from_parser_type(ctx, param_parser_type_id, false)); + } + + let return_type = self.determine_inference_type_from_parser_type(ctx, definition.return_type, false); + (parameter_types, return_type) + }, + Definition::Struct(_) | Definition::Enum(_) => { + unreachable!("insert initial polymorph data for struct/enum"); + } + } + } + }; + + self.extra_data.insert(call_id.upcast(), ExtraData { + poly_vars, + embedded: embedded_types, + returned: return_type + }); } + /// Determines the initial InferenceType from the provided ParserType. This + /// may be called with two kinds of intentions: + /// 1. To resolve a ParserType within the body of a function, or on + /// polymorphic arguments to calls/instantiations within that body. This + /// means that the polymorphic variables are known and can be replaced + /// with the monomorph we're instantiating. + /// 2. To resolve a ParserType on a called function's definition or on + /// an instantiated datatype's members. This means that the polymorphic + /// arguments inside those ParserTypes refer to the polymorphic + /// variables in the called/instantiated type's definition. + /// In the second case we place InferenceTypePart::Marker instances such + /// that we can perform type inference on the polymorphic variables. 
fn determine_inference_type_from_parser_type( - &mut self, ctx: &mut Ctx, parser_type_id: ParserTypeId + &mut self, ctx: &Ctx, parser_type_id: ParserTypeId, + parser_type_in_body: bool ) -> InferenceType { use ParserTypeVariant as PTV; - use InferredPart as IP; + use InferenceTypePart as ITP; let mut to_consider = VecDeque::with_capacity(16); to_consider.push_back(parser_type_id); let mut infer_type = Vec::new(); let mut has_inferred = false; + let mut has_markers = false; while !to_consider.is_empty() { let parser_type_id = to_consider.pop_front().unwrap(); let parser_type = &ctx.heap[parser_type_id]; match &parser_type.variant { - PTV::Message => { infer_type.push(IP::Message); }, - PTV::Bool => { infer_type.push(IP::Bool); }, - PTV::Byte => { infer_type.push(IP::Byte); }, - PTV::Short => { infer_type.push(IP::Short); }, - PTV::Int => { infer_type.push(IP::Int); }, - PTV::Long => { infer_type.push(IP::Long); }, - PTV::String => { infer_type.push(IP::String); }, + PTV::Message => { infer_type.push(ITP::Message); }, + PTV::Bool => { infer_type.push(ITP::Bool); }, + PTV::Byte => { infer_type.push(ITP::Byte); }, + PTV::Short => { infer_type.push(ITP::Short); }, + PTV::Int => { infer_type.push(ITP::Int); }, + PTV::Long => { infer_type.push(ITP::Long); }, + PTV::String => { infer_type.push(ITP::String); }, PTV::IntegerLiteral => { unreachable!("integer literal type on variable type"); }, PTV::Inferred => { - infer_type.push(IP::Unknown); + infer_type.push(ITP::Unknown); has_inferred = true; }, PTV::Array(subtype_id) => { - infer_type.push(IP::Array); + infer_type.push(ITP::Array); to_consider.push_front(*subtype_id); }, PTV::Input(subtype_id) => { - infer_type.push(IP::Input); + infer_type.push(ITP::Input); to_consider.push_front(*subtype_id); }, PTV::Output(subtype_id) => { - infer_type.push(IP::Output); + infer_type.push(ITP::Output); to_consider.push_front(*subtype_id); }, PTV::Symbolic(symbolic) => { @@ -1230,9 +1486,16 @@ impl TypeResolvingVisitor { // Retrieve 
concrete type of argument and add it to // the inference type. debug_assert!(symbolic.poly_args.is_empty()); // TODO: @hkt - debug_assert!(arg_idx < self.polyvars.len()); - for concrete_part in &self.polyvars[arg_idx].v { - infer_type.push(IP::from(*concrete_part)); + + if parser_type_in_body { + debug_assert!(arg_idx < self.polyvars.len()); + for concrete_part in &self.polyvars[arg_idx].v { + infer_type.push(ITP::from(*concrete_part)); + } + } else { + has_markers = true; + infer_type.push(ITP::Marker(arg_idx)); + infer_type.push(ITP::Unknown); } }, SymbolicParserTypeVariant::Definition(definition_id) => { @@ -1248,7 +1511,7 @@ impl TypeResolvingVisitor { debug_assert_eq!(symbolic.poly_args.len(), num_poly); } - infer_type.push(IP::Instance(definition_id, symbolic.poly_args.len())); + infer_type.push(ITP::Instance(definition_id, symbolic.poly_args.len())); let mut poly_arg_idx = symbolic.poly_args.len(); while poly_arg_idx > 0 { poly_arg_idx -= 1; @@ -1260,7 +1523,7 @@ impl TypeResolvingVisitor { } } - InferenceType::new(false, !has_inferred, infer_type) + InferenceType::new(has_markers, !has_inferred, infer_type) } /// Construct an error when an expression's type does not match. 
This @@ -1321,23 +1584,8 @@ impl TypeResolvingVisitor { ) } - fn construct_type_class_error( - &self, ctx: &Ctx, expr_id: ExpressionId, type_class: TypeClass - ) -> ParseError2 { - let expr = &ctx.heap[expr_id]; - let expr_type = self.expr_types.get(&expr_id).unwrap(); - - return ParseError2::new_error( - &ctx.module.source, expr.position(), - &format!( - "Incompatible types: got a '{}' but expected a {} type", - expr_type.display_name(&ctx.heap), type_class - ) - ) - } - fn construct_template_type_error( - &self, ctx: &Ctx, expr_id: ExpressionId, template: &[InferenceType] + &self, ctx: &Ctx, expr_id: ExpressionId, template: &[InferenceTypePart] ) -> ParseError2 { // TODO: @cleanup let fake = InferenceType::new(false, false, Vec::from(template)); diff --git a/src/protocol/parser/utils.rs b/src/protocol/parser/utils.rs index 5650c9ba98491e455cc50756aa1501f339135777..a68e03356b52591d295f8764faf987dfd6b221b9 100644 --- a/src/protocol/parser/utils.rs +++ b/src/protocol/parser/utils.rs @@ -15,6 +15,7 @@ pub(crate) enum FindTypeResult<'t, 'i> { SymbolNamespace{ident_pos: InputPosition, symbol_pos: InputPosition}, } +// TODO: @cleanup Find other uses of this pattern impl<'t, 'i> FindTypeResult<'t, 'i> { /// Utility function to transform the `FindTypeResult` into a `Result` where /// `Ok` contains the resolved type, and `Err` contains a `ParseError` which diff --git a/src/protocol/parser/visitor_linker.rs b/src/protocol/parser/visitor_linker.rs index 0b15f8eacffa8101d4be0a2d8c83d13e8f361841..b2d93e0200c3b3298de4a287ecba2cf5af5e1ced 100644 --- a/src/protocol/parser/visitor_linker.rs +++ b/src/protocol/parser/visitor_linker.rs @@ -113,6 +113,15 @@ impl ValidityAndLinkerVisitor { self.parser_type_buffer.clear(); self.insert_buffer.clear(); } + + /// Debug call to ensure that we didn't make any mistakes in any of the + /// employed buffers + fn check_post_definition_state(&self) { + debug_assert!(self.statement_buffer.is_empty()); + 
debug_assert!(self.expression_buffer.is_empty()); + debug_assert!(self.parser_type_buffer.is_empty()); + debug_assert!(self.insert_buffer.is_empty()); + } } impl Visitor2 for ValidityAndLinkerVisitor { @@ -151,7 +160,10 @@ impl Visitor2 for ValidityAndLinkerVisitor { self.performing_breadth_pass = true; self.visit_stmt(ctx, body_id)?; self.performing_breadth_pass = false; - self.visit_stmt(ctx, body_id) + self.visit_stmt(ctx, body_id)?; + + self.check_post_definition_state(); + Ok(()) } fn visit_function_definition(&mut self, ctx: &mut Ctx, id: FunctionId) -> VisitorResult { @@ -184,7 +196,10 @@ impl Visitor2 for ValidityAndLinkerVisitor { self.performing_breadth_pass = true; self.visit_stmt(ctx, body_id)?; self.performing_breadth_pass = false; - self.visit_stmt(ctx, body_id) + self.visit_stmt(ctx, body_id)?; + + self.check_post_definition_state(); + Ok(()) } //-------------------------------------------------------------------------- @@ -767,8 +782,13 @@ impl Visitor2 for ValidityAndLinkerVisitor { // Resolve the method to the appropriate definition and check the // legality of the particular method call. + // TODO: @cleanup Unify in some kind of signature call, see similar + // cleanup comments with this `match` format. 
+ let num_args; match &mut call_expr.method { - Method::Create => {}, + Method::Create => { + num_args = 1; + }, Method::Fires => { if !self.def_type.is_primitive() { return Err(ParseError2::new_error( @@ -776,6 +796,7 @@ impl Visitor2 for ValidityAndLinkerVisitor { "A call to 'fires' may only occur in primitive component definitions" )); } + num_args = 1; }, Method::Get => { if !self.def_type.is_primitive() { @@ -784,6 +805,7 @@ impl Visitor2 for ValidityAndLinkerVisitor { "A call to 'get' may only occur in primitive component definitions" )); } + num_args = 1; }, Method::Symbolic(symbolic) => { // Find symbolic method @@ -808,11 +830,29 @@ impl Visitor2 for ValidityAndLinkerVisitor { }; symbolic.definition = Some(definition_id); + match ctx.types.get_base_definition(&definition_id).unwrap() { + Definition::Function(definition) => { + num_args = definition.parameters.len(); + }, + _ => unreachable!(), + } } } - // Parse all the arguments in the depth pass as well. Note that we check - // the number of arguments in the type checker. + // Check the poly args and the number of variables in the call + // expression + self.visit_call_poly_args(ctx, id)?; + if call_expr.arguments.len() != num_args { + return Err(ParseError2::new_error( + &ctx.module.source, call_expr.position, + &format!( + "This call expects {} arguments, but {} were provided", + num_args, call_expr.arguments.len() + ) + )); + } + + // Recurse into all of the arguments and set the expression's parent let call_expr = &mut ctx.heap[id]; let upcast_id = id.upcast(); @@ -852,12 +892,114 @@ impl Visitor2 for ValidityAndLinkerVisitor { //-------------------------------------------------------------------------- fn visit_parser_type(&mut self, ctx: &mut Ctx, id: ParserTypeId) -> VisitorResult { - // We visit a particular type rooted in a non-ParserType node in the - // AST. Within this function we set up a buffer to visit all nested - // ParserType nodes. 
- // The goal is to link symbolic ParserType instances to the appropriate
- // definition or symbolic type. Alternatively to throw an error if we
- // cannot resolve the ParserType to either of these (polymorphic) types.
+ let old_num_types = self.parser_type_buffer.len();
+ match self.visit_parser_type_without_buffer_cleanup(ctx, id) {
+ Ok(_) => {
+ debug_assert_eq!(self.parser_type_buffer.len(), old_num_types);
+ Ok(())
+ },
+ Err(err) => {
+ self.parser_type_buffer.truncate(old_num_types);
+ Err(err)
+ }
+ }
+ }
+}
+
+enum FindOfTypeResult {
+ // Identifier was exactly matched, type matched as well
+ Found(DefinitionId),
+ // Identifier was matched, but the type differs from the expected one
+ TypeMismatch(&'static str),
+ // Identifier could not be found
+ NotFound,
+}
+
+impl ValidityAndLinkerVisitor {
+ //--------------------------------------------------------------------------
+ // Special traversal
+ //--------------------------------------------------------------------------
+
+ /// Will visit a statement with a hint about its wrapping statement. This is
+ /// used to distinguish block statements with a wrapping synchronous
+ /// statement from normal block statements.
+ fn visit_stmt_with_hint(&mut self, ctx: &mut Ctx, id: StatementId, hint: Option<SynchronousStatementId>) -> VisitorResult {
+ if let Statement::Block(block_stmt) = &ctx.heap[id] {
+ let block_id = block_stmt.this;
+ self.visit_block_stmt_with_hint(ctx, block_id, hint)
+ } else {
+ self.visit_stmt(ctx, id)
+ }
+ }
+
+ fn visit_block_stmt_with_hint(&mut self, ctx: &mut Ctx, id: BlockStatementId, hint: Option<SynchronousStatementId>) -> VisitorResult {
+ if self.performing_breadth_pass {
+ // Performing a breadth pass, so don't traverse into the statements
+ // of the block.
+ return Ok(())
+ }
+
+ // Set parent scope and relative position in the parent scope. Remember
+ // these values to set them back to the old values when we're done with
+ // the traversal of the block's statements. 
+ let body = &mut ctx.heap[id]; + body.parent_scope = self.cur_scope.clone(); + body.relative_pos_in_parent = self.relative_pos_in_block; + + let old_scope = self.cur_scope.replace(match hint { + Some(sync_id) => Scope::Synchronous((sync_id, id)), + None => Scope::Regular(id), + }); + let old_relative_pos = self.relative_pos_in_block; + + // Copy statement IDs into buffer + let old_num_stmts = self.statement_buffer.len(); + { + let body = &ctx.heap[id]; + self.statement_buffer.extend_from_slice(&body.statements); + } + let new_num_stmts = self.statement_buffer.len(); + + // Perform the breadth-first pass. Its main purpose is to find labeled + // statements such that we can find the `goto`-targets immediately when + // performing the depth pass + self.performing_breadth_pass = true; + for stmt_idx in old_num_stmts..new_num_stmts { + self.relative_pos_in_block = (stmt_idx - old_num_stmts) as u32; + self.visit_stmt(ctx, self.statement_buffer[stmt_idx])?; + } + + if !self.insert_buffer.is_empty() { + let body = &mut ctx.heap[id]; + for (insert_idx, (pos, stmt)) in self.insert_buffer.drain(..).enumerate() { + body.statements.insert(pos as usize + insert_idx, stmt); + } + } + + // And the depth pass. Because we're not actually visiting any inserted + // nodes because we're using the statement buffer, we may safely use the + // relative_pos_in_block counter. + self.performing_breadth_pass = false; + for stmt_idx in old_num_stmts..new_num_stmts { + self.relative_pos_in_block = (stmt_idx - old_num_stmts) as u32; + self.visit_stmt(ctx, self.statement_buffer[stmt_idx])?; + } + + self.cur_scope = old_scope; + self.relative_pos_in_block = old_relative_pos; + + // Pop statement buffer + debug_assert!(self.insert_buffer.is_empty(), "insert buffer not empty after depth pass"); + self.statement_buffer.truncate(old_num_stmts); + + Ok(()) + } + + /// Visits a particular ParserType in the AST and resolves temporary and + /// implicitly inferred types into the appropriate tree. 
Note that a + /// ParserType node is a tree. Only call this function on the root node of + /// that tree to prevent doing work more than once. + fn visit_parser_type_without_buffer_cleanup(&mut self, ctx: &mut Ctx, id: ParserTypeId) -> VisitorResult { use ParserTypeVariant as PTV; debug_assert!(!self.performing_breadth_pass); @@ -900,7 +1042,7 @@ impl Visitor2 for ValidityAndLinkerVisitor { // TODO: @hkt Maybe allow higher-kinded types? if !symbolic.poly_args.is_empty() { return Err(ParseError2::new_error( - &ctx.module.source, symbolic.identifier.position, + &ctx.module.source, symbolic.identifier.position, "Polymorphic arguments to a polymorphic variable (higher-kinded types) are not allowed (yet)" )); } @@ -931,7 +1073,7 @@ impl Visitor2 for ValidityAndLinkerVisitor { } // If the type is polymorphic then we have two cases: if - // the programmer did not specify the polyargs then we + // the programmer did not specify the polyargs then we // assume we're going to infer all of them. Otherwise we // make sure that they match in count. if !found_type.poly_args.is_empty() && symbolic.poly_args.is_empty() { @@ -945,12 +1087,12 @@ impl Visitor2 for ValidityAndLinkerVisitor { return Err(ParseError2::new_error( &ctx.module.source, symbolic.identifier.position, &format!( - "Expected {} polymorpic arguments (or none, to infer them), but {} were specified", + "Expected {} polymorphic arguments (or none, to infer them), but {} were specified", found_type.poly_args.len(), symbolic.poly_args.len() ) )) } else { - // If here then the type is not polymorphic, or all + // If here then the type is not polymorphic, or all // types are properly specified by the user. 
for specified_poly_arg in &symbolic.poly_args { self.parser_type_buffer.push(*specified_poly_arg); @@ -993,96 +1135,6 @@ impl Visitor2 for ValidityAndLinkerVisitor { Ok(()) } -} - -enum FindOfTypeResult { - // Identifier was exactly matched, type matched as well - Found(DefinitionId), - // Identifier was matched, but the type differs from the expected one - TypeMismatch(&'static str), - // Identifier could not be found - NotFound, -} - -impl ValidityAndLinkerVisitor { - //-------------------------------------------------------------------------- - // Special traversal - //-------------------------------------------------------------------------- - - /// Will visit a statement with a hint about its wrapping statement. This is - /// used to distinguish block statements with a wrapping synchronous - /// statement from normal block statements. - fn visit_stmt_with_hint(&mut self, ctx: &mut Ctx, id: StatementId, hint: Option) -> VisitorResult { - if let Statement::Block(block_stmt) = &ctx.heap[id] { - let block_id = block_stmt.this; - self.visit_block_stmt_with_hint(ctx, block_id, hint) - } else { - self.visit_stmt(ctx, id) - } - } - - fn visit_block_stmt_with_hint(&mut self, ctx: &mut Ctx, id: BlockStatementId, hint: Option) -> VisitorResult { - if self.performing_breadth_pass { - // Performing a breadth pass, so don't traverse into the statements - // of the block. - return Ok(()) - } - - // Set parent scope and relative position in the parent scope. Remember - // these values to set them back to the old values when we're done with - // the traversal of the block's statements. 
- let body = &mut ctx.heap[id]; - body.parent_scope = self.cur_scope.clone(); - body.relative_pos_in_parent = self.relative_pos_in_block; - - let old_scope = self.cur_scope.replace(match hint { - Some(sync_id) => Scope::Synchronous((sync_id, id)), - None => Scope::Regular(id), - }); - let old_relative_pos = self.relative_pos_in_block; - - // Copy statement IDs into buffer - let old_num_stmts = self.statement_buffer.len(); - { - let body = &ctx.heap[id]; - self.statement_buffer.extend_from_slice(&body.statements); - } - let new_num_stmts = self.statement_buffer.len(); - - // Perform the breadth-first pass. Its main purpose is to find labeled - // statements such that we can find the `goto`-targets immediately when - // performing the depth pass - self.performing_breadth_pass = true; - for stmt_idx in old_num_stmts..new_num_stmts { - self.relative_pos_in_block = (stmt_idx - old_num_stmts) as u32; - self.visit_stmt(ctx, self.statement_buffer[stmt_idx])?; - } - - if !self.insert_buffer.is_empty() { - let body = &mut ctx.heap[id]; - for (insert_idx, (pos, stmt)) in self.insert_buffer.drain(..).enumerate() { - body.statements.insert(pos as usize + insert_idx, stmt); - } - } - - // And the depth pass. Because we're not actually visiting any inserted - // nodes because we're using the statement buffer, we may safely use the - // relative_pos_in_block counter. 
- self.performing_breadth_pass = false; - for stmt_idx in old_num_stmts..new_num_stmts { - self.relative_pos_in_block = (stmt_idx - old_num_stmts) as u32; - self.visit_stmt(ctx, self.statement_buffer[stmt_idx])?; - } - - self.cur_scope = old_scope; - self.relative_pos_in_block = old_relative_pos; - - // Pop statement buffer - debug_assert!(self.insert_buffer.is_empty(), "insert buffer not empty after depth pass"); - self.statement_buffer.truncate(old_num_stmts); - - Ok(()) - } //-------------------------------------------------------------------------- // Utilities @@ -1411,4 +1463,73 @@ impl ValidityAndLinkerVisitor { Ok(target) } + + fn visit_call_poly_args(&mut self, ctx: &mut Ctx, call_id: CallExpressionId) -> VisitorResult { + let call_expr = &ctx.heap[call_id]; + + // Determine the polyarg signature + let num_expected_poly_args = match &call_expr.method { + Method::Create => { + 0 + }, + Method::Fires => { + 1 + }, + Method::Get => { + 1 + }, + Method::Symbolic(symbolic) => { + let definition = &ctx.heap[symbolic.definition.unwrap()]; + if let Definition::Function(definition) = definition { + definition.poly_vars.len() + } else { + debug_assert!(false, "expected function while visiting call poly args"); + unreachable!(); + } + } + }; + + // We allow zero polyargs to imply all args are inferred. Otherwise the + // number of arguments must be equal + if call_expr.poly_args.is_empty() { + if num_expected_poly_args != 0 { + // Infer all polyargs + // TODO: @cleanup Not nice to use method position as implicitly + // inferred parser type pos. 
+ let pos = call_expr.position();
+ for _ in 0..num_expected_poly_args {
+ self.parser_type_buffer.push(ctx.heap.alloc_parser_type(|this| ParserType {
+ this,
+ pos,
+ variant: ParserTypeVariant::Inferred,
+ }));
+ }
+
+ let call_expr = &mut ctx.heap[call_id];
+ call_expr.poly_args.reserve(num_expected_poly_args);
+ for _ in 0..num_expected_poly_args {
+ call_expr.poly_args.push(self.parser_type_buffer.pop().unwrap());
+ }
+ }
+ Ok(())
+ } else if call_expr.poly_args.len() == num_expected_poly_args {
+ // Number of args is not 0, so parse all the specified ParserTypes
+ let old_num_types = self.parser_type_buffer.len();
+ self.parser_type_buffer.extend(&call_expr.poly_args);
+ while self.parser_type_buffer.len() > old_num_types {
+ let parser_type_id = self.parser_type_buffer.pop().unwrap();
+ self.visit_parser_type(ctx, parser_type_id)?;
+ }
+ self.parser_type_buffer.truncate(old_num_types);
+ Ok(())
+ } else {
+ return Err(ParseError2::new_error(
+ &ctx.module.source, call_expr.position,
+ &format!(
+ "Expected {} polymorphic arguments (or none, to infer them), but {} were specified",
+ num_expected_poly_args, call_expr.poly_args.len()
+ )
+ ));
+ }
+ }
}
\ No newline at end of file