From a82fb2b1f7f93301656ef7456306305b6cc6da80 2021-03-14 18:23:45 From: MH Date: 2021-03-14 18:23:45 Subject: [PATCH] WIP on type inference --- diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index ec20aca8f6f50894e1be4b80bd4fbf0aaeb059c8..8bd519561abcbe43e9ae72b072cb647cf56d877b 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -1021,6 +1021,8 @@ impl SyntaxElement for Variable { } } +/// TODO: Remove distinction between parameter/local and add an enum to indicate +/// the distinction between the two #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct Parameter { pub this: ParameterId, diff --git a/src/protocol/parser/type_resolver.rs b/src/protocol/parser/type_resolver.rs index dc0c5f8750508c06179fa5630d05cee6ac0d73f7..2aa2171b423cffa42edc3abcc43ebe62cbf61740 100644 --- a/src/protocol/parser/type_resolver.rs +++ b/src/protocol/parser/type_resolver.rs @@ -1,5 +1,7 @@ use crate::protocol::ast::*; -use super::type_table::{ConcreteType, ConcreteTypeVariant}; +use crate::protocol::inputsource::*; +use super::type_table::*; +use super::symbol_table::*; use super::visitor::{ STMT_BUFFER_INIT_CAPACITY, EXPR_BUFFER_INIT_CAPACITY, @@ -9,11 +11,65 @@ use super::visitor::{ }; use std::collections::HashMap; -enum ExprType { - Regular, // expression statement or return statement - Memory, // memory statement's expression - Condition, // if/while conditional statement - Assert, // assert statement +pub(crate) enum InferredPart { + // Unknown section of inferred type, yet to be inferred + Unknown, + // No subtypes + Message, + Bool, + Byte, + Short, + Int, + Long, + String, + // One subtype + Array, + Slice, + Input, + Output, + // One or more subtypes + Instance(DefinitionId, usize), +} + +impl From for InferredPart { + fn from(v: ConcreteTypeVariant) -> Self { + use ConcreteTypeVariant as CTV; + use InferredPart as IP; + + match v { + CTV::Message => IP::Message, + CTV::Bool => IP::Bool, + CTV::Byte => IP::Byte, + CTV::Short => IP::Short, + CTV::Int => IP::Int, + CTV::Long => IP::Long, + CTV::String => IP::String, + CTV::Array => IP::Array, + CTV::Slice => IP::Slice, + CTV::Input => IP::Input, + CTV::Output => IP::Output, + CTV::Instance(definition_id, num_sub) => IP::Instance(definition_id, num_sub), + } + } +} + +pub(crate) struct InferenceType { + origin: ParserTypeId, + inferred: Vec, +} + +impl InferenceType { + fn new(inferred_type: ParserTypeId) -> Self { + Self{ origin: inferred_type, inferred: vec![InferredPart::Unknown] } + } + + fn assign_concrete(&mut self, concrete_type: &ConcreteType) { + self.inferred.clear(); + self.inferred.reserve(concrete_type.v.len()); + for variant in concrete_type.v { + self.inferred.push(InferredPart::from(variant)) + } + } } // TODO: @cleanup I will do a very dirty implementation first, because I have no idea @@ -37,30 +93,35 @@ enum ExprType { /// This will be achieved by slowly descending into the AST. At any given /// expression we may depend on pub(crate) struct TypeResolvingVisitor { - // Tracking traversal state - expr_type: ExprType, - // Buffers for iteration over substatements and subexpressions stmt_buffer: Vec, expr_buffer: Vec, - // Map for associating "auto"/"polyarg" variables with a concrete type where - // it is not yet determined. - env: HashMap + // If instantiating a monomorph of a polymorphic proctype, then we store the + // values of the polymorphic values here. + polyvars: Vec<(Identifier, ConcreteTypeVariant)>, + // Mapping from parser type to inferred type. We attempt to continue to + // specify these types until we're stuck or we've fully determined the type. + infer_types: HashMap, + // Mapping from variable ID to parser type, optionally inferred, so then + var_types: HashMap, } impl TypeResolvingVisitor { pub(crate) fn new() -> Self { TypeResolvingVisitor{ - expr_type: ExprType::Regular, stmt_buffer: Vec::with_capacity(STMT_BUFFER_INIT_CAPACITY), expr_buffer: Vec::with_capacity(EXPR_BUFFER_INIT_CAPACITY), - env: HashMap::new(), + infer_types: HashMap::new(), + var_types: HashMap::new(), } } fn reset(&mut self) { - self.expr_type = ExprType::Regular; + self.stmt_buffer.clear(); + self.expr_buffer.clear(); + self.infer_types.clear(); + self.var_types.clear(); } } @@ -68,13 +129,27 @@ impl Visitor2 for TypeResolvingVisitor { // Definitions fn visit_component_definition(&mut self, ctx: &mut Ctx, id: ComponentId) -> VisitorResult { + self.reset(); + let comp_def = &ctx.heap[id]; + for param_id in comp_def.parameters.clone() { + let param = &ctx.heap[param_id]; + self.var_types.insert(param_id.upcast(), param.parser_type); + } + let body_stmt_id = ctx.heap[id].body; self.visit_stmt(ctx, body_stmt_id) } fn visit_function_definition(&mut self, ctx: &mut Ctx, id: FunctionId) -> VisitorResult { + self.reset(); + let func_def = &ctx.heap[id]; + for param_id in func_def.parameters.clone() { + let param = &ctx.heap[param_id]; + self.var_types.insert(param_id.upcast(), param.parser_type); + } let body_stmt_id = ctx.heap[id].body; + self.visit_stmt(ctx, body_stmt_id) } @@ -84,25 +159,17 @@ impl Visitor2 for TypeResolvingVisitor { // Transfer statements for traversal let block = &ctx.heap[id]; - let old_len_stmts = self.stmt_buffer.len(); - self.stmt_buffer.extend(&block.statements); - let new_len_stmts = self.stmt_buffer.len(); - - // Traverse statements - for stmt_idx in old_len_stmts..new_len_stmts { - let stmt_id = self.stmt_buffer[stmt_idx]; - self.expr_type = ExprType::Regular; - self.visit_stmt(ctx, stmt_id)?; + for stmt_id in block.statements.clone() { + self.visit_stmt(ctx, stmt_id); } - self.stmt_buffer.truncate(old_len_stmts); Ok(()) } fn visit_local_memory_stmt(&mut self, ctx: &mut Ctx, id: MemoryStatementId) -> VisitorResult { let memory_stmt = &ctx.heap[id]; - - + let local = &ctx.heap[memory_stmt.variable]; + self.var_types.insert(memory_stmt.variable, ) Ok(()) } @@ -113,5 +180,111 @@ impl Visitor2 for TypeResolvingVisitor { } impl TypeResolvingVisitor { + /// Checks if the `ParserType` contains any inferred variables. If so then + /// they will be inserted into the `infer_types` variable. Here we assume + /// we're parsing the body of a proctype, so any reference to polymorphic + /// variables must refer to the polymorphic arguments of the proctype's + /// definition. + /// TODO: @cleanup: The symbol_table -> type_table pattern appears quite + /// a lot, will likely need to create some kind of function for this + fn insert_parser_type_if_needs_inference( + &mut self, ctx: &mut Ctx, root_id: RootId, parser_type_id: ParserTypeId + ) -> Result<(), ParseError2> { + use ParserTypeVariant as PTV; + + let mut to_consider = vec![parser_type_id]; + while !to_consider.is_empty() { + let parser_type_id = to_consider.pop().unwrap(); + let parser_type = &mut ctx.heap[parser_type_id]; + + match &mut parser_type.variant { + PTV::Inferred => { + self.env.insert(parser_type_id, InferenceType::new(parser_type_id)); + }, + PTV::Array(subtype_id) => { to_consider.push(*subtype_id); }, + PTV::Input(subtype_id) => { to_consider.push(*subtype_id); }, + PTV::Output(subtype_id) => { to_consider.push(*subtype_id); }, + PTV::Symbolic(symbolic) => { + // If not yet resolved, try to resolve + if symbolic.variant.is_none() { + let mut found = false; + for (poly_idx, (poly_var, _)) in self.polyvars.iter().enumerate() { + if symbolic.identifier.value == poly_var.value { + // Found a match + symbolic.variant = Some(SymbolicParserTypeVariant::PolyArg(poly_idx)) + found = true; + break; + } + } + if !found { + // Attempt to find in symbol/type table + let symbol = ctx.symbols.resolve_namespaced_symbol(root_id, &symbolic.identifier); + if symbol.is_none() { + let module_source = &ctx.module.source; + return Err(ParseError2::new_error( + module_source, symbolic.identifier.position, + "Could not resolve symbol to a type" + )); + } + + // Check if symbol was fully resolved + let (symbol, ident_iter) = symbol.unwrap(); + if ident_iter.num_remaining() != 0 { + let module_source = &ctx.module.source; + ident_iter. + return Err(ParseError2::new_error( + module_source, symbolic.identifier.position, + "Could not resolve symbol to a type" + ).with_postfixed_info( + module_source, symbol.position, + "Could resolve part of the identifier to this symbol" + )); + } + + // Check if symbol resolves to struct/enum + let definition_id = match symbol.symbol { + Symbol::Namespace(_) => { + let module_source = &ctx.module.source; + return Err(ParseError2::new_error( + module_source, symbolic.identifier.position, + "Symbol resolved to a module instead of a type" + )); + }, + Symbol::Definition((_, definition_id)) => definition_id + }; + + // Retrieve from type table and make sure it is a + // reference to a struct/enum/union + // TODO: @types Allow function pointers + let def_type = ctx.types.get_base_definition(&definition_id); + debug_assert!(def_type.is_some(), "Expected to resolve definition ID to type definition in type table"); + let def_type = def_type.unwrap(); + + let def_type_class = def_type.definition.type_class(); + if !def_type_class.is_data_type() { + return Err(ParseError2::new_error( + &ctx.module.source, symbolic.identifier.position, + &format!("Symbol refers to a {}, only data types are supported", def_type_class) + )); + } + + // Now that we're certain it is a datatype, make + // sure that the number of polyargs in the symbolic + // type matches that of the definition, or conclude + // that all polyargs need to be inferred. + if symbolic.poly_args.len() != def_type.poly_args.len() { + if symbolic.poly_args.is_empty() { + // Modify ParserType to have auto-inferred + // polymorphic arguments + symbolic.poly_args. + } + } + } + } + }, + _ => {} // Builtin, doesn't require inference + } + } + } } \ No newline at end of file diff --git a/src/protocol/parser/type_table.rs b/src/protocol/parser/type_table.rs index 1941f8da2361f7f1c25b4fea9d22a1def81a13d4..f26598128f29e2ea9ae212ee0f08edc8bad3e0df 100644 --- a/src/protocol/parser/type_table.rs +++ b/src/protocol/parser/type_table.rs @@ -82,6 +82,14 @@ impl TypeClass { TypeClass::Component => "component", } } + + pub(crate) fn is_data_type(&self) -> bool { + self == TypeClass::Enum || self == TypeClass::Union || self == TypeClass::Struct + } + + pub(crate) fn is_proc_type(&self) -> bool { + self == TypeClass::Function || self == TypeClass::Component + } } impl std::fmt::Display for TypeClass { @@ -228,6 +236,7 @@ impl TypeIterator { } } +#[derive(Copy, Clone)] pub(crate) enum ConcreteTypeVariant { // No subtypes Message, diff --git a/src/protocol/parser/visitor.rs b/src/protocol/parser/visitor.rs index 02bfaf21b5242a657c1421a777248e2844de5e29..a4794193bcf81b3fdd9a2a30e9f94a4e9d956c33 100644 --- a/src/protocol/parser/visitor.rs +++ b/src/protocol/parser/visitor.rs @@ -11,6 +11,9 @@ pub(crate) const STMT_BUFFER_INIT_CAPACITY: usize = 256; /// Globally configured vector capacity for expression buffers in visitor /// implementations pub(crate) const EXPR_BUFFER_INIT_CAPACITY: usize = 256; +/// Globally configured vector capacity for parser type buffers in visitor +/// implementations +pub(crate) const TYPE_BUFFER_INIT_CAPACITY: usize = 128; /// General context structure that is used while traversing the AST. pub(crate) struct Ctx<'p> { @@ -235,4 +238,7 @@ pub(crate) trait Visitor2 { fn visit_constant_expr(&mut self, _ctx: &mut Ctx, _id: ConstantExpressionId) -> VisitorResult { Ok(()) } fn visit_call_expr(&mut self, _ctx: &mut Ctx, _id: CallExpressionId) -> VisitorResult { Ok(()) } fn visit_variable_expr(&mut self, _ctx: &mut Ctx, _id: VariableExpressionId) -> VisitorResult { Ok(()) } + + // Types + fn visit_parser_type(&mut self, _ctx: &mut Ctx, _id: ParserTypeId) -> VisitorResult { Ok(()) } } \ No newline at end of file diff --git a/src/protocol/parser/visitor_linker.rs b/src/protocol/parser/visitor_linker.rs index 313aac2e6264a4981fa24b84733ddcf67451cc19..32cc8bc151f81871d6ebd72780ef6861198296bf 100644 --- a/src/protocol/parser/visitor_linker.rs +++ b/src/protocol/parser/visitor_linker.rs @@ -7,6 +7,7 @@ use crate::protocol::parser::{symbol_table::*, type_table::*}; use super::visitor::{ STMT_BUFFER_INIT_CAPACITY, EXPR_BUFFER_INIT_CAPACITY, + TYPE_BUFFER_INIT_CAPACITY, Ctx, Visitor2, VisitorResult @@ -15,9 +16,16 @@ use crate::protocol::ast::ExpressionParent::ExpressionStmt; #[derive(PartialEq, Eq)] enum DefinitionType { - Primitive, - Composite, - Function + None, + Primitive(ComponentId), + Composite(ComponentId), + Function(FunctionId) +} + +impl DefinitionType { + fn is_primitive(&self) -> bool { if let Self::Primitive(_) = self { true } else { false } } + fn is_composite(&self) -> bool { if let Self::Composite(_) = self { true } else { false } } + fn is_function(&self) -> bool { if let Self::Function(_) = self { true } else { false } } } /// This particular visitor will go through the entire AST in a recursive manner @@ -68,6 +76,8 @@ pub(crate) struct ValidityAndLinkerVisitor { // Another buffer, now with expression IDs, to prevent constant cloning of // vectors while working around rust's borrowing rules expression_buffer: Vec, + // Yet another buffer, now with parser type IDs, similar to above + parser_type_buffer: Vec, // Statements to insert after the breadth pass in a single block insert_buffer: Vec<(u32, StatementId)>, } @@ -79,11 +89,12 @@ impl ValidityAndLinkerVisitor { in_while: None, cur_scope: None, expr_parent: ExpressionParent::None, - def_type: DefinitionType::Primitive, + def_type: DefinitionType::None, performing_breadth_pass: false, relative_pos_in_block: 0, statement_buffer: Vec::with_capacity(STMT_BUFFER_INIT_CAPACITY), expression_buffer: Vec::with_capacity(EXPR_BUFFER_INIT_CAPACITY), + parser_type_buffer: Vec::with_capacity(TYPE_BUFFER_INIT_CAPACITY), insert_buffer: Vec::with_capacity(32), } } @@ -93,11 +104,12 @@ impl ValidityAndLinkerVisitor { self.in_while = None; self.cur_scope = None; self.expr_parent = ExpressionParent::None; - self.def_type = DefinitionType::Primitive; + self.def_type = DefinitionType::None; self.relative_pos_in_block = 0; self.performing_breadth_pass = false; self.statement_buffer.clear(); self.expression_buffer.clear(); + self.parser_type_buffer.clear(); self.insert_buffer.clear(); } } @@ -111,13 +123,30 @@ impl Visitor2 for ValidityAndLinkerVisitor { self.reset_state(); self.def_type = match &ctx.heap[id].variant { - ComponentVariant::Primitive => DefinitionType::Primitive, - ComponentVariant::Composite => DefinitionType::Composite, + ComponentVariant::Primitive => DefinitionType::Primitive(id), + ComponentVariant::Composite => DefinitionType::Composite(id), }; self.cur_scope = Some(Scope::Definition(id.upcast())); self.expr_parent = ExpressionParent::None; - let body_id = ctx.heap[id].body; + // Visit types of parameters + debug_assert!(self.parser_type_buffer.is_empty()); + let comp_def = &ctx.heap[id]; + self.parser_type_buffer.extend( + comp_def.parameters + .iter() + .map(|id| ctx.heap[*id].parser_type) + ); + + let num_types = self.parser_type_buffer.len(); + for idx in 0..num_types { + self.visit_parser_type(ctx, self.parser_type_buffer[idx])?; + } + + self.parser_type_buffer.clear(); + + // Visit statements in component body + let body_id = ctx.heap[id].body; self.performing_breadth_pass = true; self.visit_stmt(ctx, body_id)?; self.performing_breadth_pass = false; @@ -128,10 +157,28 @@ impl Visitor2 for ValidityAndLinkerVisitor { self.reset_state(); // Set internal statement indices - self.def_type = DefinitionType::Function; + self.def_type = DefinitionType::Function(id); self.cur_scope = Some(Scope::Definition(id.upcast())); self.expr_parent = ExpressionParent::None; + // Visit types of parameters + debug_assert!(self.parser_type_buffer.is_empty()); + let func_def = &ctx.heap[id]; + self.parser_type_buffer.extend( + func_def.parameters + .iter() + .map(|id| ctx.heap[*id].parser_type) + ); + self.parser_type_buffer.push(func_def.return_type); + + let num_types = self.parser_type_buffer.len(); + for idx in 0..num_types { + self.visit_parser_type(ctx, self.parser_type_buffer[idx])?; + } + + self.parser_type_buffer.clear(); + + // Visit statements in function body let body_id = ctx.heap[id].body; self.performing_breadth_pass = true; self.visit_stmt(ctx, body_id)?; @@ -152,6 +199,10 @@ impl Visitor2 for ValidityAndLinkerVisitor { let variable_id = ctx.heap[id].variable; self.checked_local_add(ctx, self.relative_pos_in_block, variable_id)?; } else { + let variable_id = ctx.heap[id].variable; + let parser_type_id = ctx.heap[variable_id].parser_type; + self.visit_parser_type(ctx, parser_type_id); + debug_assert_eq!(self.expr_parent, ExpressionParent::None); self.expr_parent = ExpressionParent::Memory(id); self.visit_expr(ctx, ctx.heap[id].initial)?; @@ -169,6 +220,12 @@ impl Visitor2 for ValidityAndLinkerVisitor { }; self.checked_local_add(ctx, self.relative_pos_in_block, from_id)?; self.checked_local_add(ctx, self.relative_pos_in_block, to_id)?; + } else { + let chan_stmt = &ctx.heap[id]; + let from_type_id = ctx.heap[chan_stmt.from].parser_type; + let to_type_id = ctx.heap[chan_stmt.to].parser_type; + self.visit_parser_type(ctx, from_type_id)?; + self.visit_parser_type(ctx, to_type_id)?; } Ok(()) @@ -304,7 +361,7 @@ impl Visitor2 for ValidityAndLinkerVisitor { ); } - if self.def_type != DefinitionType::Primitive { + if !self.def_type.is_primitive() { return Err(ParseError2::new_error( &ctx.module.source, cur_sync_position, "Synchronous statements may only be used in primitive components" @@ -334,7 +391,7 @@ impl Visitor2 for ValidityAndLinkerVisitor { fn visit_return_stmt(&mut self, ctx: &mut Ctx, id: ReturnStatementId) -> VisitorResult { if self.performing_breadth_pass { let stmt = &ctx.heap[id]; - if self.def_type != DefinitionType::Function { + if !self.def_type.is_function() { return Err( ParseError2::new_error(&ctx.module.source, stmt.position, "Return statements may only appear in function bodies") ); @@ -353,7 +410,7 @@ impl Visitor2 for ValidityAndLinkerVisitor { fn visit_assert_stmt(&mut self, ctx: &mut Ctx, id: AssertStatementId) -> VisitorResult { let stmt = &ctx.heap[id]; if self.performing_breadth_pass { - if self.def_type == DefinitionType::Function { + if self.def_type.is_function() { // TODO: We probably want to allow this. Mark the function as // using asserts, and then only allow calls to these functions // within components. Such a marker will cascade through any @@ -415,7 +472,7 @@ impl Visitor2 for ValidityAndLinkerVisitor { // TODO: Cleanup error messages, can be done cleaner // Make sure new statement occurs within a composite component let call_expr_id = ctx.heap[id].expression; - if self.def_type != DefinitionType::Composite { + if !self.def_type.is_composite() { let new_stmt = &ctx.heap[id]; return Err( ParseError2::new_error(&ctx.module.source, new_stmt.position, "Instantiating components may only be done in composite components") @@ -712,7 +769,7 @@ impl Visitor2 for ValidityAndLinkerVisitor { match &mut call_expr.method { Method::Create => {}, Method::Fires => { - if self.def_type != DefinitionType::Primitive { + if !self.def_type.is_primitive() { return Err(ParseError2::new_error( &ctx.module.source, call_expr.position, "A call to 'fires' may only occur in primitive component definitions" @@ -720,7 +777,7 @@ impl Visitor2 for ValidityAndLinkerVisitor { } }, Method::Get => { - if self.def_type != DefinitionType::Primitive { + if !self.def_type.is_primitive() { return Err(ParseError2::new_error( &ctx.module.source, call_expr.position, "A call to 'get' may only occur in primitive component definitions" @@ -788,6 +845,64 @@ impl Visitor2 for ValidityAndLinkerVisitor { Ok(()) } + + //-------------------------------------------------------------------------- + // ParserType visitors + //-------------------------------------------------------------------------- + + fn visit_parser_type(&mut self, ctx: &mut Ctx, id: ParserTypeId) -> VisitorResult { + // We visit a particular type rooted in a non-ParserType node in the + // AST. Within this function we set up a buffer to visit all nested + // ParserType nodes. + // The goal is to link symbolic ParserType instances to the appropriate + // definition or symbolic type. Alternatively to throw an error if we + // cannot resolve the ParserType to either of these (polymorphic) types. + use ParserTypeVariant as PTV; + debug_assert!(!self.performing_breadth_pass); + + let init_num_types = self.parser_type_buffer.len(); + self.parser_type_buffer.push(id); + + 'resolve_loop: while self.parser_type_buffer.len() > init_num_types { + let parser_type_id = self.parser_type_buffer.pop().unwrap(); + let parser_type = &ctx.heap[parser_type_id]; + + match &parser_type.variant { + PTV::Message | PTV::Bool | + PTV::Byte | PTV::Short | PTV::Int | PTV::Long | + PTV::String | + PTV::IntegerLiteral | PTV::Inferred => { + // Builtin types or types that do not require recursion + continue 'resolve_loop; + }, + PTV::Array(subtype_id) | + PTV::Input(subtype_id) | + PTV::Output(subtype_id) => { + // Requires recursing + self.parser_type_buffer.push(*subtype_id); + continue 'resolve_loop; + }, + PTV::Symbolic(symbolic) => { + // Retrieve poly_vars from function/component definition to + // match against. + let poly_vars = match self.def_type { + DefinitionType::None => unreachable!(), + DefinitionType::Primitive(id) => &ctx.heap[id].poly_vars, + DefinitionType::Composite(id) => &ctx.heap[id].poly_vars, + DefinitionType::Function(id) => &ctx.heap[id].poly_vars, + }; + + for (poly_var_idx, poly_var) in poly_vars.iter().enumerate() { + if symbolic.identifier.value == poly_var.value { + + } + } + } + } + } + + Ok(()) + } } enum FindOfTypeResult {