Changeset - d673a04baa22
[Not reviewed]
0 1 0
MH - 4 years ago 2021-03-17 15:49:28
contact@maxhenger.nl
WIP on 'simple' type inference
1 file changed with 346 insertions and 17 deletions:
0 comments (0 inline, 0 general)
src/protocol/parser/type_resolver.rs
Show inline comments
 
use std::{collections::{HashMap, VecDeque}, fmt::Display};
 
use std::collections::{HashMap, HashSet, VecDeque};
 

	
 
use crate::protocol::ast::*;
 
use crate::protocol::inputsource::*;
 
@@ -19,6 +19,7 @@ pub(crate) enum InferredPart {
 
    // Special cases
 
    Void, // For builtin functions without a return type. Result of a "call to a component"
 
    IntegerLike, // For integer literals without a concrete type
 
    ArrayLike, // Array or Slice
 
    // No subtypes
 
    Message,
 
    Bool,
 
@@ -44,6 +45,14 @@ impl InferredPart {
 
            _ => false
 
        }
 
    }
 

	
 
    fn is_concrete_arraylike(&self) -> bool {
 
        use InferredPart as IP;
 
        match self {
 
            IP::Slice | IP::Array => true,
 
            _ => false
 
        }
 
    }
 
}
 

	
 
impl From<ConcreteTypeVariant> for InferredPart {
 
@@ -73,6 +82,11 @@ pub(crate) struct InferenceType {
 
    parts: Vec<InferredPart>,
 
}
 

	
 
const BOOL_TEMPLATE: InferenceType = InferenceType{ done: true, parts: vec![InferredPart::Bool] };
 
const INTEGERLIKE_TEMPLATE: InferenceType = InferenceType{ done: false, parts: vec![InferredPart::IntegerLike] };
 
const ARRAYLIKE_TEMPLATE: InferenceType = InferenceType{ done: false, parts: vec![InferredPart::ArrayLike, InferredPart::Unknown] };
 
const SLICE_TEMPLATE: InferenceType = InferenceType{ done: false, parts: vec![InferredPart::Slice, InferredPart::Unknown] };
 

	
 
impl InferenceType {
 
    fn new(done: bool, inferred: Vec<InferredPart>) -> Self {
 
        Self{ done, parts: inferred }
 
@@ -90,7 +104,7 @@ impl InferenceType {
 
                IP::String => {
 
                    depth -= 1;
 
                },
 
                IP::Array | IP::Slice | IP::Input | IP::Output => {
 
                IP::ArrayLike | IP::Array | IP::Slice | IP::Input | IP::Output => {
 
                    // depth remains unaltered
 
                },
 
                IP::Instance(_, num_sub) => {
 
@@ -122,6 +136,17 @@ impl InferenceType {
 
        self.might_be_numeric()
 
    }
 

	
 
    fn might_be_boolean(&self) -> bool {
 
        use InferredPart as IP;
 
        debug_assert!(!self.parts.is_empty());
 
        if self.parts.len() != 1 { return false; }
 
        match self.parts[0] {
 
            IP::Unknown | IP::Bool => true,
 
            _ =>
 
                false
 
        }
 
    }
 

	
 
    fn display_name(&self, heap: &Heap) -> String {
 
        use InferredPart as IP;
 

	
 
@@ -137,6 +162,11 @@ impl InferenceType {
 
                IP::Int => v.push_str("int"),
 
                IP::Long => v.push_str("long"),
 
                IP::String => v.push_str("str"),
 
                IP::ArrayLike => {
 
                    *idx += 1;
 
                    write_recursive(v, t, h, idx);
 
                    v.push_str("[?]");
 
                }
 
                IP::Array => {
 
                    *idx += 1;
 
                    write_recursive(v, t, h, idx);
 
@@ -207,6 +237,18 @@ impl InferenceResult {
 
            _ => false
 
        }
 
    }
 
    fn modified_rhs(&self) -> bool {
 
        match self {
 
            InferenceResult::Second | InferenceResult::Both => true,
 
            _ => false
 
        }
 
    }
 
}
 

	
 
fn progress_inference_part_and_advance_idx(
 
    a: &mut InferenceType, b: &mut InferenceType, iter_idx: &mut usize
 
) -> Result<(), InferenceResult> {
 

	
 
}
 

	
 
// Attempts to infer types within two `InferenceType` instances. If they are
 
@@ -241,13 +283,16 @@ unsafe fn progress_inference_types(a: *mut InferenceType, b: *mut InferenceType)
 

	
 
        // Not an exact match, so deal with edge cases
 
        // - inference of integerlike types to conrete integers
 
        if *a_part == InferredPart::IntegerLike && b_part.is_concrete_int() {
 
        // - inference of arraylike to array/slice
 
        if (*a_part == InferredPart::IntegerLike && b_part.is_concrete_int()) ||
 
            (*a_part == InferredPart::ArrayLike && b_part.is_concrete_arraylike()) {
 
            modified_a = true;
 
            a.parts[iter_idx] = b_part.clone();
 
            iter_idx += 1;
 
            continue;
 
        }
 
        if *b_part == InferredPart::IntegerLike && a_part.is_concrete_int() {
 
        if (*b_part == InferredPart::IntegerLike && a_part.is_concrete_int()) ||
 
            (*b_part == InferredPart::ArrayLike && a_part.is_concrete_arraylike()){
 
            modified_b = true;
 
            b.parts[iter_idx] = a_part.clone();
 
            iter_idx += 1;
 
@@ -283,7 +328,7 @@ unsafe fn progress_inference_types(a: *mut InferenceType, b: *mut InferenceType)
 
        return InferenceResult::Incompatible;
 
    }
 

	
 
    // TODO: @performance, can be done inline
 
    // TODO: @performance, can be done in the loop above
 
    a.done = true;
 
    for part in &a.parts {
 
        if *part == InferredPart::Unknown || *part == InferredPart::IntegerLike {
 
@@ -312,6 +357,69 @@ unsafe fn progress_inference_types(a: *mut InferenceType, b: *mut InferenceType)
 
    }
 
}
 

	
 
enum InferenceTemplateResult {
 
    Unmodified,
 
    Modified,
 
    Incompatible,
 
}
 

	
 
fn progress_inference_with_template(t: &mut InferenceType, template: &InferenceType) -> InferenceTemplateResult {
 
    debug_assert!(!t.parts.is_empty());
 
    debug_assert!(!template.parts.is_empty());
 

	
 
    // Iterate over the elements of both types, enforce the template onto the
 
    // provided mutable inference type.
 
    let mut modified = false;
 
    let mut iter_idx = 0;
 

	
 
    while iter_idx < t.parts.len() {
 
        let part = &t.parts[iter_idx];
 
        let template_part = &template.parts[iter_idx];
 

	
 
        if part == template_part {
 
            iter_idx += 1;
 
            continue
 
        }
 

	
 
        // Check if we can infer anything from the template
 
        if (*part == InferredPart::IntegerLike && template_part.is_concrete_int()) ||
 
            (*part == InferredPart::ArrayLike && template_part.is_concrete_arraylike()) {
 
            modified = true;
 
            t.parts[iter_idx] = template_part.clone();
 
            iter_idx += 1;
 
            continue;
 
        }
 

	
 
        if *part == InferredPart::Unknown {
 
            let end_idx = template.find_subtree_end_idx(iter_idx);
 
            t.parts[iter_idx] = template.parts[iter_idx].clone();
 
            for insert_idx in (iter_idx + 1)..end_idx {
 
                t.parts.insert(insert_idx, template.parts[insert_idx].clone());
 
            }
 

	
 
            modified = true;
 
            iter_idx = end_idx;
 
            continue;
 
        }
 

	
 
        // Because we're dealing with a template some mismatched parts are still
 
        // valid
 
        if (part.is_concrete_arraylike() && *template_part == InferredPart::ArrayLike) ||
 
            (part.is_concrete_int() && *template_part == InferredPart::IntegerLike) {
 
            iter_idx += 1;
 
            continue;
 
        }
 

	
 
        return InferenceTemplateResult::Incompatible;
 
    }
 

	
 
    if modified {
 
        InferenceTemplateResult::Modified
 
    } else {
 
        InferenceTemplateResult::Unmodified
 
    }
 
}
 

	
 
enum DefinitionType{
 
    None,
 
    Component(ComponentId),
 
@@ -353,6 +461,7 @@ pub(crate) struct TypeResolvingVisitor {
 
    // specify these types until we're stuck or we've fully determined the type.
 
    infer_types: HashMap<VariableId, InferenceType>,
 
    expr_types: HashMap<ExpressionId, InferenceType>,
 
    expr_queued: HashSet<ExpressionId>,
 
}
 

	
 
impl TypeResolvingVisitor {
 
@@ -364,6 +473,7 @@ impl TypeResolvingVisitor {
 
            polyvars: Vec::new(),
 
            infer_types: HashMap::new(),
 
            expr_types: HashMap::new(),
 
            expr_queued: HashSet::new(),
 
        }
 
    }
 

	
 
@@ -549,7 +659,7 @@ impl Visitor2 for TypeResolvingVisitor {
 
        self.visit_expr(ctx, left_expr_id)?;
 
        self.visit_expr(ctx, right_expr_id)?;
 

	
 
        // TODO: Initial progress?
 
        self.progress_assignment_expr(ctx, id)
 
    }
 

	
 
    fn visit_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> VisitorResult {
 
@@ -566,21 +676,50 @@ impl Visitor2 for TypeResolvingVisitor {
 
        self.visit_expr(ctx, true_expr_id)?;
 
        self.visit_expr(ctx, false_expr_id)?;
 

	
 
        // TODO: Initial progress?
 
        Ok(())
 
        self.progress_conditional_expr(ctx, id)
 
    }
 

	
 
    fn visit_binary_expr(&mut self, ctx: &mut Ctx, id: BinaryExpressionId) -> VisitorResult {
 
        let upcast_id = id.upcast();
 
        self.insert_initial_expr_inference_type(ctx, upcast_id);
 

	
 
        let binary_expr = &ctx.heap[id];
 
        let lhs_expr_id = binary_expr.left;
 
        let rhs_expr_id = binary_expr.right;
 

	
 
        self.visit_expr(ctx, lhs_expr_id)?;
 
        self.visit_expr(ctx, rhs_expr_id)?;
 

	
 
        self.progress_binary_expr(ctx, id)
 
    }
 

	
 
    fn visit_unary_expr(&mut self, ctx: &mut Ctx, id: UnaryExpressionId) -> VisitorResult {
 
        let upcast_id = id.upcast();
 
        self.insert_initial_expr_inference_type(ctx, upcast_id);
 

	
 
        let unary_expr = &ctx.heap[id];
 
        let arg_expr_id = unary_expr.expression;
 

	
 
        self.visit_expr(ctx, arg_expr_id);
 

	
 
        self.progress_unary_expr(ctx, id)
 
    }
 
}
 

	
 
// TODO: @cleanup Decide to use this where appropriate or to make templates for
 
//  everything
 
enum TypeClass {
 
    Numeric, // int and float
 
    Integer,
 
    Integer, // only ints
 
    Boolean, // only boolean
 
}
 

	
 
impl std::fmt::Display for TypeClass {
 
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
 
        write!(f, "{}", match self {
 
            TypeClass::Numeric => "numeric",
 
            TypeClass::Integer => "integer"
 
            TypeClass::Integer => "integer",
 
            TypeClass::Boolean => "boolean",
 
        })
 
    }
 
}
 
@@ -614,7 +753,7 @@ enum TypeConstraintResult {
 
}
 

	
 
impl TypeResolvingVisitor {
 
    fn progress_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> Result<bool, ParseError2> {
 
    fn progress_assignment_expr(&mut self, ctx: &mut Ctx, id: AssignmentExpressionId) -> Result<(), ParseError2> {
 
        use AssignmentOperator as AO;
 

	
 
        // TODO: Assignable check
 
@@ -634,13 +773,182 @@ impl TypeResolvingVisitor {
 
        };
 

	
 
        let upcast_id = id.upcast();
 
        let made_progress = self.apply_equal3_constraint(ctx, upcast_id, arg1_expr_id, arg2_expr_id)?;
 
        let (progress_expr, progress_arg1, progress_arg2) = self.apply_equal3_constraint(
 
            ctx, upcast_id, arg1_expr_id, arg2_expr_id
 
        )?;
 

	
 
        if let Some(type_class) = type_class {
 
            self.expr_type_is_of_type_class(ctx, id.upcast(), type_class)?
 
        }
 

	
 
        Ok(made_progress)
 
        if progress_expr { self.queue_expr_parent(ctx, upcast_id); }
 
        if progress_arg1 { self.queue_expr(arg1_expr_id); }
 
        if progress_arg2 { self.queue_expr(arg2_expr_id); }
 

	
 
        Ok(())
 
    }
 

	
 
    fn progress_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> Result<(), ParseError2> {
 
        // Note: test expression type is already enforced
 
        let upcast_id = id.upcast();
 
        let expr = &ctx.heap[id];
 
        let arg1_expr_id = expr.true_expression;
 
        let arg2_expr_id = expr.false_expression;
 

	
 
        let (progress_expr, progress_arg1, progress_arg2) = self.apply_equal3_constraint(
 
            ctx, upcast_id, arg1_expr_id, arg2_expr_id
 
        )?;
 

	
 
        if progress_expr { self.queue_expr_parent(ctx, upcast_id); }
 
        if progress_arg1 { self.queue_expr(arg1_expr_id); }
 
        if progress_arg2 { self.queue_expr(arg2_expr_id); }
 

	
 
        Ok(())
 
    }
 

	
 
    fn progress_binary_expr(&mut self, ctx: &mut Ctx, id: BinaryExpressionId) -> Result<(), ParseError2> {
 
        // Note: our expression type might be fixed by our parent, but we still
 
        // need to make sure it matches the type associated with our operation.
 
        use BinaryOperator as BO;
 

	
 
        let upcast_id = id.upcast();
 
        let expr = &ctx.heap[id];
 
        let arg1_id = expr.left;
 
        let arg2_id = expr.right;
 

	
 
        let (progress_expr, progress_arg1, progress_arg2) = match expr.operation {
 
            BO::Concatenate => {
 
                // Equal3 with array-like appearance, two slice arguments result
 
                // in an array output
 
                todo!("implement concatenate typechecking");
 
                (false, false, false)
 
            },
 
            BO::LogicalOr | BO::LogicalAnd => {
 
                // Forced boolean on all
 
                let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?;
 
                let progress_arg1 = self.apply_forced_constraint(ctx, arg1_id, &BOOL_TEMPLATE)?;
 
                let progress_arg2 = self.apply_forced_constraint(ctx, arg2_id, &BOOL_TEMPLATE)?;
 

	
 
                (progress_expr, progress_arg1, progress_arg2)
 
            },
 
            BO::BitwiseOr | BO::BitwiseXor | BO::BitwiseAnd | BO::Remainder | BO::ShiftLeft | BO::ShiftRight => {
 
                let result = self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id)?;
 
                self.expr_type_is_of_type_class(ctx, upcast_id, TypeClass::Integer)?;
 
                result
 
            },
 
            BO::Equality | BO::Inequality | BO::LessThan | BO::GreaterThan | BO::LessThanEqual | BO::GreaterThanEqual => {
 
                // Equal2 on args, forced boolean output
 
                let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?;
 
                let (progress_arg1, progress_arg2) = self.apply_equal2_constraint(ctx, upcast_id, arg1_id, arg2_id);
 
                self.expr_type_is_of_type_class(ctx, arg1_id, TypeClass::Numeric)?;
 

	
 
                (progress_expr, progress_arg1, progress_arg2)
 
            },
 
            BO::Add | BO::Subtract | BO::Multiply | BO::Divide => {
 
                let result = self.apply_equal3_constraint(ctx, upcast_id, arg1_id, arg2_id)?;
 
                self.expr_type_is_of_type_class(ctx, upcast_id, TypeClass::Numeric)?;
 
                result
 
            },
 
        };
 

	
 
        if progress_expr { self.queue_expr_parent(ctx, upcast_id); }
 
        if progress_arg1 { self.queue_expr(arg1_id); }
 
        if progress_arg2 { self.queue_expr(arg2_id); }
 

	
 
        Ok(())
 
    }
 

	
 
    fn progress_unary_expr(&mut self, ctx: &mut Ctx, id: UnaryExpressionId) -> Result<(), ParseError2> {
 
        use UnaryOperation as UO;
 

	
 
        let upcast_id = id.upcast();
 
        let expr = &ctx.heap[id];
 
        let arg_id = expr.expression;
 

	
 
        let (progress_expr, progress_arg) = match expr.operation {
 
            UO::Positive | UO::Negative => {
 
                // Equal types of numeric class
 
                let progress = self.apply_equal2_constraint(ctx, upcast_id, upcast_id, arg_id)?;
 
                self.expr_type_is_of_type_class(ctx, upcast_id, TypeClass::Numeric)?;
 
                progress
 
            },
 
            UO::BitwiseNot | UO::PreIncrement | UO::PreDecrement | UO::PostIncrement | UO::PostDecrement => {
 
                // Equal types of integer class
 
                let progress = self.apply_equal2_constraint(ctx, upcast_id, upcast_id, arg_id)?;
 
                self.expr_type_is_of_type_class(ctx, upcast_id, TypeClass::Integer)?;
 
                progress
 
            },
 
            UO::LogicalNot => {
 
                // Both booleans
 
                let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?;
 
                let progress_arg = self.apply_forced_constraint(ctx, upcast_id, &BOOL_TEMPLATE)?;
 
                (progress_expr, progress_arg)
 
            }
 
        };
 

	
 
        if progress_expr { self.queue_expr_parent(ctx, upcast_id); }
 
        if progress_arg { self.queue_expr(arg_id); }
 

	
 
        Ok(())
 
    }
 

	
 
    fn progress_indexing_expr(&mut self, ctx: &mut Ctx, id: IndexingExpressionId) -> Result<(), ParseError2> {
 
        // TODO: Indexable check
 
        let upcast_id = id.upcast();
 
        let expr = &ctx.heap[id];
 
        let subject_id = expr.subject;
 
        let index_id = expr.index;
 

	
 
        let progress_subject = self.apply_forced_constraint(ctx, subject_id, &ARRAYLIKE_TEMPLATE)?;
 
        let progress_index = self.apply_forced_constraint(ctx, index_id, &INTEGERLIKE_TEMPLATE)?;
 

	
 
    }
 

	
 
    fn queue_expr_parent(&mut self, ctx: &Ctx, expr_id: ExpressionId) {
 
        if let ExpressionParent::Expression(parent_expr_id, _) = &ctx.heap[expr_id].parent() {
 
            self.expr_queued.insert(*parent_expr_id)
 
        }
 
    }
 

	
 
    fn queue_expr(&mut self, expr_id: ExpressionId) {
 
        self.expr_queued.insert(expr_id);
 
    }
 

	
 
    /// Applies a forced type constraint: the type associated with the supplied
 
    /// expression will be molded into the provided "template". The template may
 
    /// be fully specified (e.g. a bool) or contain "inference" variables (e.g.
 
    /// an array of T)
 
    fn apply_forced_constraint(
 
        &mut self, ctx: &mut Ctx, expr_id: ExpressionId, template: &InferenceType
 
    ) -> Result<bool, ParseError2> {
 
        debug_assert_expr_ids_unique_and_known!(self, expr_id);
 
        let expr_type = self.expr_types.get_mut(&expr_id).unwrap();
 
        match progress_inference_with_template(expr_type, template) {
 
            InferenceTemplateResult::Modified => Ok(true),
 
            InferenceTemplateResult::Unmodified => Ok(false),
 
            InferenceTemplateResult::Incompatible => Err(
 
                self.construct_template_type_error(ctx, expr_id, template)
 
            )
 
        }
 
    }
 

	
 
    /// Applies a type constraint that expects the two provided types to be
 
    /// equal. We attempt to make progress in inferring the types. If the call
 
    /// is successful then the composition of all types are made equal.
 
    /// The "parent" `expr_id` is provided to construct errors.
 
    fn apply_equal2_constraint(
 
        &mut self, ctx: &mut Ctx, expr_id: ExpressionId, arg1_id: ExpressionId, arg2_id: ExpressionId
 
    ) -> Result<(bool, bool), ParseError2> {
 
        debug_assert_expr_ids_unique_and_known!(self, arg1_id, arg2_id);
 
        let arg1_type: *mut _ = self.expr_types.get_mut(&arg1_id).unwrap();
 
        let arg2_type: *mut _ = self.expr_types.get_mut(&arg2_id).unwrap();
 

	
 
        let infer_res = unsafe{ progress_inference_types(expr_idarg1_type, arg2_type) };
 
        if infer_res == InferenceResult::Incompatible {
 
            return Err(self.construct_arg_type_error(ctx, expr_id, arg1_id, arg2_id));
 
        }
 

	
 
        Ok((infer_res.modified_lhs(), infer_res.modified_rhs()))
 
    }
 

	
 
    /// Applies a type constraint that expects all three provided types to be
 
@@ -650,7 +958,7 @@ impl TypeResolvingVisitor {
 
    fn apply_equal3_constraint(
 
        &mut self, ctx: &mut Ctx, expr_id: ExpressionId,
 
        arg1_id: ExpressionId, arg2_id: ExpressionId
 
    ) -> Result<bool, ParseError2> {
 
    ) -> Result<(bool, bool, bool), ParseError2> {
 
        // Safety: all expression IDs are always distinct, and we do not modify
 
        //  the container
 
        debug_assert_expr_ids_unique_and_known!(self, expr_id, arg1_id, arg2_id);
 
@@ -670,15 +978,20 @@ impl TypeResolvingVisitor {
 

	
 
        // If all types are compatible, but the second call caused the arg1_type
 
        // to be expanded, then we must also assign this to expr_type.
 
        let mut progress_expr = expr_res.modified_lhs();
 
        let mut progress_arg1 = expr_res.modified_rhs();
 
        let mut progress_arg2 = args_res.modified_rhs();
 

	
 
        if args_res.modified_lhs() { 
 
            unsafe {
 
                (*expr_type).parts.clear();
 
                (*expr_type).parts.extend((*arg2_type).parts.iter());
 
            }
 
            progress_expr = true;
 
            progress_arg1 = true;
 
        }
 

	
 
        let made_progress = expr_res.modified_any() || args_res.modified_any();
 
        Ok(made_progress)
 
        Ok((progress_expr, progress_arg1, progress_arg2))
 
    }
 

	
 
    /// Applies a typeclass constraint: checks if the type is of a particular
 
@@ -691,7 +1004,8 @@ impl TypeResolvingVisitor {
 

	
 
        let is_ok = match type_class {
 
            TypeClass::Numeric => expr_type.might_be_numeric(),
 
            TypeClass::Integer => expr_type.might_be_integer()
 
            TypeClass::Integer => expr_type.might_be_integer(),
 
            TypeClass::Boolean => expr_type.might_be_boolean(),
 
        };
 

	
 
        if is_ok {
 
@@ -907,4 +1221,19 @@ impl TypeResolvingVisitor {
 
            )
 
        )
 
    }
 

	
 
    fn construct_template_type_error(
 
        &self, ctx: &Ctx, expr_id: ExpressionId, template: &InferenceType
 
    ) -> ParseError2 {
 
        let expr = &ctx.heap[expr_id];
 
        let expr_type = self.expr_types.get(&expr_id).unwrap();
 

	
 
        return ParseError2::new_error(
 
            &ctx.module.source, expr.position(),
 
            &format!(
 
                "Incompatible types: got a '{}' but expected a '{}'",
 
                expr_type.display_name(&ctx.heap), template.display_name(&ctx.heap)
 
            )
 
        )
 
    }
 
}
 
\ No newline at end of file
0 comments (0 inline, 0 general)