From d673a04baa221e9a6dbcf081cb46a0b73ec4a723 2021-03-17 15:49:28 From: MH Date: 2021-03-17 15:49:28 Subject: [PATCH] WIP on 'simple' type inference --- diff --git a/src/protocol/parser/type_resolver.rs b/src/protocol/parser/type_resolver.rs index 1ab3ee6fa37ad31bb0883d97d5155911f571ba7b..85d522f635c47efe7a6692254fe0c48f517e20af 100644 --- a/src/protocol/parser/type_resolver.rs +++ b/src/protocol/parser/type_resolver.rs @@ -1,4 +1,4 @@ -use std::{collections::{HashMap, VecDeque}, fmt::Display}; +use std::collections::{HashMap, HashSet, VecDeque}; use crate::protocol::ast::*; use crate::protocol::inputsource::*; @@ -19,6 +19,7 @@ pub(crate) enum InferredPart { // Special cases Void, // For builtin functions without a return type. Result of a "call to a component" IntegerLike, // For integer literals without a concrete type + ArrayLike, // Array or Slice // No subtypes Message, Bool, @@ -44,6 +45,14 @@ impl InferredPart { _ => false } } + + fn is_concrete_arraylike(&self) -> bool { + use InferredPart as IP; + match self { + IP::Slice | IP::Array => true, + _ => false + } + } } impl From for InferredPart { @@ -73,6 +82,12 @@ pub(crate) struct InferenceType { parts: Vec<InferredPart>, } +// Templates are functions rather than consts: `vec!` allocates, which is not allowed in a const initializer +fn bool_template() -> InferenceType { InferenceType{ done: true, parts: vec![InferredPart::Bool] } } +fn integerlike_template() -> InferenceType { InferenceType{ done: false, parts: vec![InferredPart::IntegerLike] } } +fn arraylike_template() -> InferenceType { InferenceType{ done: false, parts: vec![InferredPart::ArrayLike, InferredPart::Unknown] } } +fn slice_template() -> InferenceType { InferenceType{ done: false, parts: vec![InferredPart::Slice, InferredPart::Unknown] } } + impl InferenceType { fn new(done: bool, inferred: Vec<InferredPart>) -> Self { Self{ done, parts: inferred } @@ -90,7 +105,7 @@ impl InferenceType { IP::String => { depth -= 1; }, - IP::Array | IP::Slice | IP::Input | IP::Output => { + IP::ArrayLike | IP::Array | IP::Slice | IP::Input | IP::Output => { // depth remains unaltered }, 
IP::Instance(_, num_sub) => { @@ -122,6 +137,17 @@ impl InferenceType { self.might_be_numeric() } + fn might_be_boolean(&self) -> bool { + use InferredPart as IP; + debug_assert!(!self.parts.is_empty()); + if self.parts.len() != 1 { return false; } + match self.parts[0] { + IP::Unknown | IP::Bool => true, + _ => + false + } + } + fn display_name(&self, heap: &Heap) -> String { use InferredPart as IP; @@ -137,6 +163,11 @@ impl InferenceType { IP::Int => v.push_str("int"), IP::Long => v.push_str("long"), IP::String => v.push_str("str"), + IP::ArrayLike => { + *idx += 1; + write_recursive(v, t, h, idx); + v.push_str("[?]"); + } IP::Array => { *idx += 1; write_recursive(v, t, h, idx); @@ -207,6 +238,18 @@ impl InferenceResult { _ => false } } + fn modified_rhs(&self) -> bool { + match self { + InferenceResult::Second | InferenceResult::Both => true, + _ => false + } + } +} + +fn progress_inference_part_and_advance_idx( + a: &mut InferenceType, b: &mut InferenceType, iter_idx: &mut usize +) -> Result<(), InferenceResult> { + todo!() } // Attempts to infer types within two `InferenceType` instances. 
If they are @@ -241,13 +284,16 @@ unsafe fn progress_inference_types(a: *mut InferenceType, b: *mut InferenceType) // Not an exact match, so deal with edge cases // - inference of integerlike types to conrete integers - if *a_part == InferredPart::IntegerLike && b_part.is_concrete_int() { + // - inference of arraylike to array/slice + if (*a_part == InferredPart::IntegerLike && b_part.is_concrete_int()) || + (*a_part == InferredPart::ArrayLike && b_part.is_concrete_arraylike()) { modified_a = true; a.parts[iter_idx] = b_part.clone(); iter_idx += 1; continue; } - if *b_part == InferredPart::IntegerLike && a_part.is_concrete_int() { + if (*b_part == InferredPart::IntegerLike && a_part.is_concrete_int()) || + (*b_part == InferredPart::ArrayLike && a_part.is_concrete_arraylike()){ modified_b = true; b.parts[iter_idx] = a_part.clone(); iter_idx += 1; @@ -283,7 +329,7 @@ unsafe fn progress_inference_types(a: *mut InferenceType, b: *mut InferenceType) return InferenceResult::Incompatible; } - // TODO: @performance, can be done inline + // TODO: @performance, can be done in the loop above a.done = true; for part in &a.parts { if *part == InferredPart::Unknown || *part == InferredPart::IntegerLike { @@ -312,6 +358,81 @@ unsafe fn progress_inference_types(a: *mut InferenceType, b: *mut InferenceType) } } +enum InferenceTemplateResult { + Unmodified, + Modified, + Incompatible, +} + +fn progress_inference_with_template(t: &mut InferenceType, template: &InferenceType) -> InferenceTemplateResult { + debug_assert!(!t.parts.is_empty()); + debug_assert!(!template.parts.is_empty()); + + // Iterate over the elements of both types, enforce the template onto the + // provided mutable inference type. Separate indices are kept for `t` and + // `template`: a template `Unknown` matches an entire (possibly multi-part) + // subtree of `t`, after which the two indices no longer coincide.
+ let mut modified = false; + let mut iter_idx = 0; + let mut template_idx = 0; + + while iter_idx < t.parts.len() && template_idx < template.parts.len() { + let part = &t.parts[iter_idx]; + let template_part = &template.parts[template_idx]; + + if part == template_part { + iter_idx += 1; + template_idx += 1; + continue + } + + // Check if we can infer anything from the template + if (*part == InferredPart::IntegerLike && template_part.is_concrete_int()) || + (*part == InferredPart::ArrayLike && template_part.is_concrete_arraylike()) { + modified = true; + t.parts[iter_idx] = template_part.clone(); + iter_idx += 1; + template_idx += 1; + continue; + } + + if *part == InferredPart::Unknown { + // Replace the unknown part by the entire corresponding template subtree + let end_idx = template.find_subtree_end_idx(template_idx); + t.parts.splice(iter_idx..iter_idx + 1, template.parts[template_idx..end_idx].iter().cloned()); + + modified = true; + iter_idx += end_idx - template_idx; + template_idx = end_idx; + continue; + } + + if *template_part == InferredPart::Unknown { + // The template places no constraint on this subtree, skip all of it + iter_idx = t.find_subtree_end_idx(iter_idx); + template_idx += 1; + continue; + } + + // Because we're dealing with a template some mismatched parts are still + // valid + if (part.is_concrete_arraylike() && *template_part == InferredPart::ArrayLike) || + (part.is_concrete_int() && *template_part == InferredPart::IntegerLike) { + iter_idx += 1; + template_idx += 1; + continue; + } + + return InferenceTemplateResult::Incompatible; + } + + if modified { + InferenceTemplateResult::Modified + } else { + InferenceTemplateResult::Unmodified + } +} + enum DefinitionType{ None, Component(ComponentId), @@ -353,6 +474,7 @@ pub(crate) struct TypeResolvingVisitor { // specify these types until we're stuck or we've fully determined the type. infer_types: HashMap, expr_types: HashMap<ExpressionId, InferenceType>, + expr_queued: HashSet<ExpressionId>, } impl TypeResolvingVisitor { @@ -364,6 +486,7 @@ impl TypeResolvingVisitor { polyvars: Vec::new(), infer_types: HashMap::new(), expr_types: HashMap::new(), + expr_queued: HashSet::new(), } } @@ -549,7 +672,7 @@ impl Visitor2 for TypeResolvingVisitor { self.visit_expr(ctx, left_expr_id)?; self.visit_expr(ctx, right_expr_id)?; - // TODO: Initial progress?
+ self.progress_assignment_expr(ctx, id) } fn visit_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> VisitorResult { @@ -566,21 +689,50 @@ impl Visitor2 for TypeResolvingVisitor { self.visit_expr(ctx, true_expr_id)?; self.visit_expr(ctx, false_expr_id)?; - // TODO: Initial progress? - Ok(()) + self.progress_conditional_expr(ctx, id) + } + + fn visit_binary_expr(&mut self, ctx: &mut Ctx, id: BinaryExpressionId) -> VisitorResult { + let upcast_id = id.upcast(); + self.insert_initial_expr_inference_type(ctx, upcast_id); + + let binary_expr = &ctx.heap[id]; + let lhs_expr_id = binary_expr.left; + let rhs_expr_id = binary_expr.right; + + self.visit_expr(ctx, lhs_expr_id)?; + self.visit_expr(ctx, rhs_expr_id)?; + + self.progress_binary_expr(ctx, id) + } + + fn visit_unary_expr(&mut self, ctx: &mut Ctx, id: UnaryExpressionId) -> VisitorResult { + let upcast_id = id.upcast(); + self.insert_initial_expr_inference_type(ctx, upcast_id); + + let unary_expr = &ctx.heap[id]; + let arg_expr_id = unary_expr.expression; + + self.visit_expr(ctx, arg_expr_id)?; + + self.progress_unary_expr(ctx, id) } } +// TODO: @cleanup Decide to use this where appropriate or to make templates for +// everything enum TypeClass { Numeric, // int and float - Integer, + Integer, // only ints + Boolean, // only boolean } impl std::fmt::Display for TypeClass { fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { write!(f, "{}", match self { TypeClass::Numeric => "numeric", - TypeClass::Integer => "integer" + TypeClass::Integer => "integer", + TypeClass::Boolean => "boolean", }) } } @@ -614,7 +766,7 @@ enum TypeConstraintResult { } impl TypeResolvingVisitor { - fn progress_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> Result<bool, ParseError2> { + fn progress_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> Result<(), ParseError2> { use AssignmentOperator as AO; // TODO: Assignable check @@ -634,13 +786,186 @@ impl TypeResolvingVisitor {
}; let upcast_id = id.upcast(); - let made_progress = self.apply_equal3_constraint(ctx, upcast_id, arg1_expr_id, arg2_expr_id)?; + let (progress_expr, progress_arg1, progress_arg2) = self.apply_equal3_constraint( + ctx, upcast_id, arg1_expr_id, arg2_expr_id + )?; if let Some(type_class) = type_class { self.expr_type_is_of_type_class(ctx, id.upcast(), type_class)? } - Ok(made_progress) + if progress_expr { self.queue_expr_parent(ctx, upcast_id); } + if progress_arg1 { self.queue_expr(arg1_expr_id); } + if progress_arg2 { self.queue_expr(arg2_expr_id); } + + Ok(()) + } + + fn progress_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> Result<(), ParseError2> { + // Note: test expression type is already enforced + let upcast_id = id.upcast(); + let expr = &ctx.heap[id]; + let arg1_expr_id = expr.true_expression; + let arg2_expr_id = expr.false_expression; + + let (progress_expr, progress_arg1, progress_arg2) = self.apply_equal3_constraint( + ctx, upcast_id, arg1_expr_id, arg2_expr_id + )?; + + if progress_expr { self.queue_expr_parent(ctx, upcast_id); } + if progress_arg1 { self.queue_expr(arg1_expr_id); } + if progress_arg2 { self.queue_expr(arg2_expr_id); } + + Ok(()) + } + + fn progress_binary_expr(&mut self, ctx: &mut Ctx, id: BinaryExpressionId) -> Result<(), ParseError2> { + // Note: our expression type might be fixed by our parent, but we still + // need to make sure it matches the type associated with our operation.
+ use BinaryOperator as BO; + + let upcast_id = id.upcast(); + let expr = &ctx.heap[id]; + let arg1_id = expr.left; + let arg2_id = expr.right; + + let (progress_expr, progress_arg1, progress_arg2) = match expr.operation { + BO::Concatenate => { + // Equal3 with array-like appearance, two slice arguments result + // in an array output + todo!("implement concatenate typechecking"); + (false, false, false) + }, + BO::LogicalOr | BO::LogicalAnd => { + // Forced boolean on all + let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &bool_template())?; + let progress_arg1 = self.apply_forced_constraint(ctx, arg1_id, &bool_template())?; + let progress_arg2 = self.apply_forced_constraint(ctx, arg2_id, &bool_template())?; + + (progress_expr, progress_arg1, progress_arg2) + }, + BO::BitwiseOr | BO::BitwiseXor | BO::BitwiseAnd | BO::Remainder | BO::ShiftLeft | BO::ShiftRight => { + let result = self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id)?; + self.expr_type_is_of_type_class(ctx, upcast_id, TypeClass::Integer)?; + result + }, + BO::Equality | BO::Inequality | BO::LessThan | BO::GreaterThan | BO::LessThanEqual | BO::GreaterThanEqual => { + // Equal2 on args, forced boolean output + let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &bool_template())?; + let (progress_arg1, progress_arg2) = self.apply_equal2_constraint(ctx, upcast_id, arg1_id, arg2_id)?; + self.expr_type_is_of_type_class(ctx, arg1_id, TypeClass::Numeric)?; + + (progress_expr, progress_arg1, progress_arg2) + }, + BO::Add | BO::Subtract | BO::Multiply | BO::Divide => { + let result = self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id)?; + self.expr_type_is_of_type_class(ctx, upcast_id, TypeClass::Numeric)?; + result + }, + }; + + if progress_expr { self.queue_expr_parent(ctx, upcast_id); } + if progress_arg1 { self.queue_expr(arg1_id); } + if progress_arg2 { self.queue_expr(arg2_id); } + + Ok(()) + } + + fn progress_unary_expr(&mut self, ctx: &mut Ctx, id:
UnaryExpressionId) -> Result<(), ParseError2> { + use UnaryOperation as UO; + + let upcast_id = id.upcast(); + let expr = &ctx.heap[id]; + let arg_id = expr.expression; + + let (progress_expr, progress_arg) = match expr.operation { + UO::Positive | UO::Negative => { + // Equal types of numeric class + let progress = self.apply_equal2_constraint(ctx, upcast_id, upcast_id, arg_id)?; + self.expr_type_is_of_type_class(ctx, upcast_id, TypeClass::Numeric)?; + progress + }, + UO::BitwiseNot | UO::PreIncrement | UO::PreDecrement | UO::PostIncrement | UO::PostDecrement => { + // Equal types of integer class + let progress = self.apply_equal2_constraint(ctx, upcast_id, upcast_id, arg_id)?; + self.expr_type_is_of_type_class(ctx, upcast_id, TypeClass::Integer)?; + progress + }, + UO::LogicalNot => { + // Both booleans + let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &bool_template())?; + let progress_arg = self.apply_forced_constraint(ctx, arg_id, &bool_template())?; + (progress_expr, progress_arg) + } + }; + + if progress_expr { self.queue_expr_parent(ctx, upcast_id); } + if progress_arg { self.queue_expr(arg_id); } + + Ok(()) + } + + fn progress_indexing_expr(&mut self, ctx: &mut Ctx, id: IndexingExpressionId) -> Result<(), ParseError2> { + // TODO: Indexable check + // TODO: the indexing expression's own type should follow the subject's element type + let expr = &ctx.heap[id]; + let subject_id = expr.subject; + let index_id = expr.index; + + let progress_subject = self.apply_forced_constraint(ctx, subject_id, &arraylike_template())?; + let progress_index = self.apply_forced_constraint(ctx, index_id, &integerlike_template())?; + + if progress_subject { self.queue_expr(subject_id); } + if progress_index { self.queue_expr(index_id); } + + Ok(()) + } + + fn queue_expr_parent(&mut self, ctx: &Ctx, expr_id: ExpressionId) { + if let ExpressionParent::Expression(parent_expr_id, _) = &ctx.heap[expr_id].parent() { + self.expr_queued.insert(*parent_expr_id); + } + } + + fn queue_expr(&mut self, expr_id: ExpressionId) { + self.expr_queued.insert(expr_id); + } + + /// Applies a forced type constraint: the type associated with the supplied + ///
expression will be molded into the provided "template". The template may + /// be fully specified (e.g. a bool) or contain "inference" variables (e.g. + /// an array of T) + fn apply_forced_constraint( + &mut self, ctx: &mut Ctx, expr_id: ExpressionId, template: &InferenceType + ) -> Result<bool, ParseError2> { + debug_assert_expr_ids_unique_and_known!(self, expr_id); + let expr_type = self.expr_types.get_mut(&expr_id).unwrap(); + match progress_inference_with_template(expr_type, template) { + InferenceTemplateResult::Modified => Ok(true), + InferenceTemplateResult::Unmodified => Ok(false), + InferenceTemplateResult::Incompatible => Err( + self.construct_template_type_error(ctx, expr_id, template) + ) + } + } + + /// Applies a type constraint that expects the two provided types to be + /// equal. We attempt to make progress in inferring the types. If the call + /// is successful then the composition of all types are made equal. + /// The "parent" `expr_id` is provided to construct errors. + fn apply_equal2_constraint( + &mut self, ctx: &mut Ctx, expr_id: ExpressionId, arg1_id: ExpressionId, arg2_id: ExpressionId + ) -> Result<(bool, bool), ParseError2> { + debug_assert_expr_ids_unique_and_known!(self, arg1_id, arg2_id); + let arg1_type: *mut _ = self.expr_types.get_mut(&arg1_id).unwrap(); + let arg2_type: *mut _ = self.expr_types.get_mut(&arg2_id).unwrap(); + + let infer_res = unsafe{ progress_inference_types(arg1_type, arg2_type) }; // Safety: arg1_id and arg2_id are distinct (see the debug assertion above), so the pointers do not alias + if infer_res == InferenceResult::Incompatible { + return Err(self.construct_arg_type_error(ctx, expr_id, arg1_id, arg2_id)); + } + + Ok((infer_res.modified_lhs(), infer_res.modified_rhs())) } /// Applies a type constraint that expects all three provided types to be @@ -650,7 +975,7 @@ impl TypeResolvingVisitor { fn apply_equal3_constraint( &mut self, ctx: &mut Ctx, expr_id: ExpressionId, arg1_id: ExpressionId, arg2_id: ExpressionId - ) -> Result<bool, ParseError2> { + ) -> Result<(bool, bool, bool), ParseError2> { // Safety: all expression IDs are always distinct,
and we do not modify // the container debug_assert_expr_ids_unique_and_known!(self, expr_id, arg1_id, arg2_id); @@ -665,20 +990,25 @@ impl TypeResolvingVisitor { let args_res = unsafe{ progress_inference_types(arg1_type, arg2_type) }; if args_res == InferenceResult::Incompatible { - return Err(self.construct_arg_type_error(ctx, expr_id, arg1_id, arg2_id)); + return Err(self.construct_arg_type_error(ctx, expr_id, arg1_id, arg2_id)); } // If all types are compatible, but the second call caused the arg1_type // to be expanded, then we must also assign this to expr_type. + let mut progress_expr = expr_res.modified_lhs(); + let mut progress_arg1 = expr_res.modified_rhs(); + let progress_arg2 = args_res.modified_rhs(); + if args_res.modified_lhs() { unsafe { (*expr_type).parts.clear(); (*expr_type).parts.extend((*arg2_type).parts.iter()); } + progress_expr = true; + progress_arg1 = true; } - let made_progress = expr_res.modified_any() || args_res.modified_any(); - Ok(made_progress) + Ok((progress_expr, progress_arg1, progress_arg2)) } /// Applies a typeclass constraint: checks if the type is of a particular @@ -691,7 +1021,8 @@ impl TypeResolvingVisitor { let is_ok = match type_class { TypeClass::Numeric => expr_type.might_be_numeric(), - TypeClass::Integer => expr_type.might_be_integer() + TypeClass::Integer => expr_type.might_be_integer(), + TypeClass::Boolean => expr_type.might_be_boolean(), }; if is_ok { @@ -907,4 +1238,19 @@ impl TypeResolvingVisitor { ) ) } + + fn construct_template_type_error( + &self, ctx: &Ctx, expr_id: ExpressionId, template: &InferenceType + ) -> ParseError2 { + let expr = &ctx.heap[expr_id]; + let expr_type = self.expr_types.get(&expr_id).unwrap(); + + ParseError2::new_error( + &ctx.module.source, expr.position(), + &format!( + "Incompatible types: got a '{}' but expected a '{}'", + expr_type.display_name(&ctx.heap), template.display_name(&ctx.heap) + ) + ) + } } \ No newline at end of file