Changeset - 393872d4a8f8
[Not reviewed]
0 1 0
MH - 4 years ago 2021-03-17 12:23:30
henger@cwi.nl
yet another bit of progress on type inference
1 file changed with 168 insertions and 42 deletions:
0 comments (0 inline, 0 general)
src/protocol/parser/type_resolver.rs
Show inline comments
 
use std::collections::{HashMap, VecDeque};
 
use std::{collections::{HashMap, VecDeque}, fmt::Display};
 

	
 
use crate::protocol::ast::*;
 
use crate::protocol::inputsource::*;
 
@@ -121,14 +121,71 @@ impl InferenceType {
 
        // TODO: @float Once floats are implemented this is no longer true
 
        self.might_be_numeric()
 
    }
 

	
 
    fn display_name(&self, heap: &Heap) -> String {
 
        use InferredPart as IP;
 

	
 
        fn write_recursive(v: &mut String, t: &InferenceType, h: &Heap, idx: &mut usize) {
 
            match &t.parts[*idx] {
 
                IP::Unknown => v.push_str("?"),
 
                IP::Void => v.push_str("void"),
 
                IP::IntegerLike => v.push_str("int?"),
 
                IP::Message => v.push_str("msg"),
 
                IP::Bool => v.push_str("bool"),
 
                IP::Byte => v.push_str("byte"),
 
                IP::Short => v.push_str("short"),
 
                IP::Int => v.push_str("int"),
 
                IP::Long => v.push_str("long"),
 
                IP::String => v.push_str("str"),
 
                IP::Array => {
 
                    *idx += 1;
 
                    write_recursive(v, t, h, idx);
 
                    v.push_str("[]");
 
                },
 
                IP::Slice => {
 
                    *idx += 1;
 
                    write_recursive(v, t, h, idx);
 
                    v.push_str("[..]")
 
                },
 
                IP::Input => {
 
                    *idx += 1;
 
                    v.push_str("in<");
 
                    write_recursive(v, t, h, idx);
 
                    v.push('>');
 
                },
 
                IP::Output => {
 
                    *idx += 1;
 
                    v.push_str("out<");
 
                    write_recursive(v, t, h, idx);
 
                    v.push('>');
 
                },
 
                IP::Instance(definition_id, num_sub) => {
 
                    let definition = &h[*definition_id];
 
                    v.push_str(&String::from_utf8_lossy(&definition.identifier().value));
 
                    if *num_sub > 0 {
 
                        v.push('<');
 
                        *idx += 1;
 
                        write_recursive(v, t, h, idx);
 
                        for sub_idx in 1..*num_sub {
 
                            *idx += 1;
 
                            v.push_str(", ");
 
                            write_recursive(v, t, h, idx);
 
                        }
 
                        v.push('>');
 
                    }
 
                },
 
            }
 
        }
 

	
 
impl std::fmt::Display for InferenceType {
 
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 
        fn write_recursive(arg: )
 
        let mut buffer = String::with_capacity(self.parts.len() * 5);
 
        let mut idx = 0;
 
        write_recursive(&mut buffer, self, heap, &mut idx);
 

	
 
        buffer
 
    }
 
}
 

	
 
#[derive(PartialEq, Eq)]
 
enum InferenceResult {
 
    Neither,        // neither argument is clarified
 
    First,          // first argument is clarified using the second one
 
@@ -157,9 +214,11 @@ impl InferenceResult {
 
// After successful inference the parts of the inferred type have equal length.
 
//
 
// Personal note: inference is not the copying of types: this algorithm must
 
// infer that `TypeOuter<TypeA, auto>` and `Struct<auto, TypeB>` resolves to
 
// infer that `TypeOuter<TypeA, auto>` and `TypeOuter<auto, TypeB>` resolves to
 
// `TypeOuter<TypeA, TypeB>`.
 
unsafe fn progress_inference_types(a: *mut InferenceType, b: *mut InferenceType) -> InferenceResult {
 
    let a = &mut *a;
 
    let b = &mut *b;
 
    debug_assert!(!a.parts.is_empty());
 
    debug_assert!(!b.parts.is_empty());
 

	
 
@@ -182,13 +241,13 @@ unsafe fn progress_inference_types(a: *mut InferenceType, b: *mut InferenceType)
 

	
 
        // Not an exact match, so deal with edge cases
 
        // - inference of integerlike types to concrete integers
 
        if a_part == InferredPart::IntegerLike && b_part.is_concrete_int() {
 
        if *a_part == InferredPart::IntegerLike && b_part.is_concrete_int() {
 
            modified_a = true;
 
            a.parts[iter_idx] = b_part.clone();
 
            iter_idx += 1;
 
            continue;
 
        }
 
        if b_part == InferredPart::IntegerLike && a_part.is_concrete_int() {
 
        if *b_part == InferredPart::IntegerLike && a_part.is_concrete_int() {
 
            modified_b = true;
 
            b.parts[iter_idx] = a_part.clone();
 
            iter_idx += 1;
 
@@ -196,7 +255,7 @@ unsafe fn progress_inference_types(a: *mut InferenceType, b: *mut InferenceType)
 
        }
 

	
 
        // - inference of unknown type
 
        if a_part == InferredPart::Unknown {
 
        if *a_part == InferredPart::Unknown {
 
            let end_idx = b.find_subtree_end_idx(iter_idx);
 
            a.parts[iter_idx] = b.parts[iter_idx].clone();
 
            for insert_idx in (iter_idx + 1)..end_idx {
 
@@ -208,7 +267,7 @@ unsafe fn progress_inference_types(a: *mut InferenceType, b: *mut InferenceType)
 
            continue;
 
        }
 

	
 
        if b_part == InferredPart::Unknown {
 
        if *b_part == InferredPart::Unknown {
 
            let end_idx = a.find_subtree_end_idx(iter_idx);
 
            b.parts[iter_idx] = a.parts[iter_idx].clone();
 
            for insert_idx in (iter_idx + 1)..end_idx {
 
@@ -227,7 +286,7 @@ unsafe fn progress_inference_types(a: *mut InferenceType, b: *mut InferenceType)
 
    // TODO: @performance, can be done inline
 
    a.done = true;
 
    for part in &a.parts {
 
        if part == InferredPart::Unknown {
 
        if *part == InferredPart::Unknown || *part == InferredPart::IntegerLike {
 
            a.done = false;
 
            break
 
        }
 
@@ -235,7 +294,7 @@ unsafe fn progress_inference_types(a: *mut InferenceType, b: *mut InferenceType)
 

	
 
    b.done = true;
 
    for part in &b.parts {
 
        if part == InferredPart::Unknown {
 
        if *part == InferredPart::Unknown || *part == InferredPart::IntegerLike {
 
            b.done = false;
 
            break;
 
        }
 
@@ -447,7 +506,7 @@ impl Visitor2 for TypeResolvingVisitor {
 
        let assert_stmt = &ctx.heap[id];
 
        let test_expr_id = assert_stmt.expression;
 

	
 
        self.visit_expr(ctx, expr_id)
 
        self.visit_expr(ctx, test_expr_id)
 
    }
 

	
 
    fn visit_new_stmt(&mut self, ctx: &mut Ctx, id: NewStatementId) -> VisitorResult {
 
@@ -481,7 +540,7 @@ impl Visitor2 for TypeResolvingVisitor {
 

	
 
    fn visit_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> VisitorResult {
 
        let upcast_id = id.upcast();
 
        self.expr_types.insert(upcast_id, self.determine_initial_expr_inference_type(ctx, upcast_id));
 
        self.insert_initial_expr_inference_type(ctx, upcast_id);
 

	
 
        let assign_expr = &ctx.heap[id];
 
        let left_expr_id = assign_expr.left;
 
@@ -495,7 +554,7 @@ impl Visitor2 for TypeResolvingVisitor {
 

	
 
    fn visit_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> VisitorResult {
 
        let upcast_id = id.upcast();
 
        self.expr_types.insert(upcast_id, self.determine_initial_expr_inference_type(ctx, upcast_id));
 
        self.insert_initial_expr_inference_type(ctx, upcast_id);
 

	
 
        let conditional_expr = &ctx.heap[id];
 
        let test_expr_id = conditional_expr.test;
 
@@ -517,24 +576,33 @@ enum TypeClass {
 
    Integer,
 
}
 

	
 
/// User-facing name of a type class, as embedded in type-error messages.
impl std::fmt::Display for TypeClass {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Pick the lowercase class name and write it out verbatim
        let class_name = match self {
            TypeClass::Numeric => "numeric",
            TypeClass::Integer => "integer",
        };
        f.write_str(class_name)
    }
}
 

	
 
/// Debug-only sanity check for a set of expression IDs: every ID must have an
/// entry in `$resolver.expr_types` and all IDs must be pairwise distinct.
///
/// The first argument is the resolver (anything with an `expr_types` map,
/// typically `self`), followed by one or more expression ID identifiers. All
/// checks are built on `debug_assert!`/`debug_assert_ne!`, so they are only
/// evaluated when debug assertions are enabled.
macro_rules! debug_assert_expr_ids_unique_and_known {
    // Base case for a single expression ID: it only has to be known
    ($resolver:ident, $id:ident) => {
        debug_assert!(
            $resolver.expr_types.contains_key(&$id),
            "expression ID has no registered inference type"
        );
    };
    // Base case for two expression IDs
    ($resolver:ident, $id1:ident, $id2:ident) => {
        debug_assert_ne!($id1, $id2);
        debug_assert_expr_ids_unique_and_known!($resolver, $id1);
        debug_assert_expr_ids_unique_and_known!($resolver, $id2);
    };
    // Generic case: compare the first ID against *all* remaining IDs, then
    // recurse on the remainder so that every pair ends up being compared.
    ($resolver:ident, $id1:ident, $id2:ident, $($tail:ident),+) => {
        debug_assert_ne!($id1, $id2);
        $( debug_assert_ne!($id1, $tail); )+
        debug_assert_expr_ids_unique_and_known!($resolver, $id1);
        debug_assert_expr_ids_unique_and_known!($resolver, $id2, $($tail),+);
    };
}
 

	
 
@@ -546,7 +614,7 @@ enum TypeConstraintResult {
 
}
 

	
 
impl TypeResolvingVisitor {
 
    fn progress_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) {
 
    fn progress_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> Result<bool, ParseError2> {
 
        use AssignmentOperator as AO;
 

	
 
        // TODO: Assignable check
 
@@ -566,7 +634,13 @@ impl TypeResolvingVisitor {
 
        };
 

	
 
        let upcast_id = id.upcast();
 
        self.apply_equal3_constraint(ctx, upcast_id, arg1_expr_id, arg2_expr_id);
 
        let made_progress = self.apply_equal3_constraint(ctx, upcast_id, arg1_expr_id, arg2_expr_id)?;
 

	
 
        if let Some(type_class) = type_class {
 
            self.expr_type_is_of_type_class(ctx, id.upcast(), type_class)?
 
        }
 

	
 
        Ok(made_progress)
 
    }
 

	
 
    /// Applies a type constraint that expects all three provided types to be
 
@@ -579,42 +653,51 @@ impl TypeResolvingVisitor {
 
    ) -> Result<bool, ParseError2> {
 
        // Safety: all expression IDs are always distinct, and we do not modify
 
        //  the container
 
        debug_assert_expr_ids_unique_and_known!(id1, id2, id3);
 
        debug_assert_expr_ids_unique_and_known!(self, expr_id, arg1_id, arg2_id);
 
        let expr_type: *mut _ = self.expr_types.get_mut(&expr_id).unwrap();
 
        let arg1_type: *mut _ = self.expr_types.get_mut(&arg1_id).unwrap();
 
        let arg2_type: *mut _ = self.expr_types.get_mut(&arg2_id).unwrap();
 

	
 
        let expr_res = unsafe{ progress_inference_types(expr_type, arg1_type) };
 
        if expr_res == InferenceResult::Incompatible { return TypeConstraintResult::ErrExprType; }
 
        if expr_res == InferenceResult::Incompatible { 
 
            return Err(self.construct_expr_type_error(ctx, expr_id, arg1_id));
 
        }
 

	
 
        let args_res = unsafe{ progress_inference_types(arg1_type, arg2_type) };
 
        if args_res == InferenceResult::Incompatible { return TypeConstraintResult::ErrArgType; }
 
        if args_res == InferenceResult::Incompatible { 
 
            return Err(self.construct_arg_type_error(ctx, expr_id, arg1_id, arg2_id)); 
 
        }
 

	
 
        // If all types are compatible, but the second call caused type2 to be
 
        // expanded, then we must also re-expand type1.
 
        // If all types are compatible, but the second call caused the arg1_type
 
        // to be expanded, then we must also assign this to expr_type.
 
        if args_res.modified_lhs() { 
 
            type1.parts.clear();
 
            type1.parts.extend(&type2.parts);
 
            unsafe {
 
                (*expr_type).parts.clear();
 
                (*expr_type).parts.extend((*arg2_type).parts.iter());
 
            }
 

	
 
        if expr_res.modified_any() || args_res.modified_any() {
 
            TypeConstraintResult::Progress
 
        } else {
 
            TypeConstraintResult::NoProgess
 
        }
 

	
 
        let made_progress = expr_res.modified_any() || args_res.modified_any();
 
        Ok(made_progress)
 
    }
 

	
 
    /// Applies a typeclass constraint: checks if the type is of a particular
 
    /// class or not
 
    fn expr_type_is_of_type_class(
 
        &mut self, ctx: &mut Ctx, expr_id: ExpressionId, type_class: TypeClass
 
    ) -> bool {
 
        debug_assert_expr_ids_unique_and_known!(expr_id);
 
    ) -> Result<(), ParseError2> {
 
        debug_assert_expr_ids_unique_and_known!(self, expr_id);
 
        let expr_type = self.expr_types.get(&expr_id).unwrap();
 

	
 
        match type_class {
 
        let is_ok = match type_class {
 
            TypeClass::Numeric => expr_type.might_be_numeric(),
 
            TypeClass::Integer => expr_type.might_be_integer()
 
        };
 

	
 
        if is_ok {
 
            Ok(())
 
        } else {
 
            Err(self.construct_type_class_error(ctx, expr_id, type_class))
 
        }
 
    }
 

	
 
@@ -768,10 +851,16 @@ impl TypeResolvingVisitor {
 

	
 
        return ParseError2::new_error(
 
            &ctx.module.source, expr.position(),
 
            "Incompatible types: this expression expected a '{}'"
 
            &format!(
 
                "Incompatible types: this expression expected a '{}'", 
 
                expr_type.display_name(&ctx.heap)
 
            )
 
        ).with_postfixed_info(
 
            &ctx.module.source, arg_expr.position(),
 
            "But this expression yields a '{}'"
 
            &format!(
 
                "But this expression yields a '{}'",
 
                arg_type.display_name(&ctx.heap)
 
            )
 
        )
 
    }
 

	
 
@@ -779,6 +868,43 @@ impl TypeResolvingVisitor {
 
        &self, ctx: &Ctx, expr_id: ExpressionId,
 
        arg1_id: ExpressionId, arg2_id: ExpressionId
 
    ) -> ParseError2 {
 
        // TODO: Expand and provide more meaningful information for humans
 
        let expr = &ctx.heap[expr_id];
 
        let arg1 = &ctx.heap[arg1_id];
 
        let arg2 = &ctx.heap[arg2_id];
 

	
 
        let arg1_type = self.expr_types.get(&arg1_id).unwrap();
 
        let arg2_type = self.expr_types.get(&arg2_id).unwrap();
 

	
 
        return ParseError2::new_error(
 
            &ctx.module.source, expr.position(),
 
            "Incompatible types: cannot apply this expression"
 
        ).with_postfixed_info(
 
            &ctx.module.source, arg1.position(),
 
            &format!(
 
                "Because this expression has type '{}'",
 
                arg1_type.display_name(&ctx.heap)
 
            )
 
        ).with_postfixed_info(
 
            &ctx.module.source, arg2.position(),
 
            &format!(
 
                "But this expression has type '{}'",
 
                arg2_type.display_name(&ctx.heap)
 
            )
 
        )
 
    }
 

	
 
    /// Builds the error for an expression whose inferred type does not belong
    /// to the required `type_class`. The error is positioned at the
    /// expression and names both the inferred type and the expected class.
    ///
    /// Panics if `expr_id` has no entry in `expr_types`.
    fn construct_type_class_error(
        &self, ctx: &Ctx, expr_id: ExpressionId, type_class: TypeClass
    ) -> ParseError2 {
        let offending_expr = &ctx.heap[expr_id];
        let inferred_type = self.expr_types.get(&expr_id).unwrap();

        let position = offending_expr.position();
        let message = format!(
            "Incompatible types: got a '{}' but expected a {} type",
            inferred_type.display_name(&ctx.heap), type_class
        );

        ParseError2::new_error(&ctx.module.source, position, &message)
    }
 
}
 
\ No newline at end of file
0 comments (0 inline, 0 general)