diff --git a/src/protocol/ast.rs b/src/protocol/ast.rs index 2d81d1aabb71209306ea0e326894ffd1af6378e8..05b1e5b41e43b621eb400e0024c9fb4b98126645 100644 --- a/src/protocol/ast.rs +++ b/src/protocol/ast.rs @@ -110,7 +110,6 @@ define_new_ast_id!(ReturnStatementId, StatementId, ReturnStatement, Statement::R define_new_ast_id!(AssertStatementId, StatementId, AssertStatement, Statement::Assert, statements); define_new_ast_id!(GotoStatementId, StatementId, GotoStatement, Statement::Goto, statements); define_new_ast_id!(NewStatementId, StatementId, NewStatement, Statement::New, statements); -define_new_ast_id!(PutStatementId, StatementId, PutStatement, Statement::Put, statements); define_new_ast_id!(ExpressionStatementId, StatementId, ExpressionStatement, Statement::Expression, statements); define_aliased_ast_id!(ExpressionId, Id, Expression, expressions); @@ -421,14 +420,6 @@ impl Heap { self.statements.alloc_with_id(|id| Statement::New(f(NewStatementId(id)))), ) } - pub fn alloc_put_statement( - &mut self, - f: impl FnOnce(PutStatementId) -> PutStatement, - ) -> PutStatementId { - PutStatementId( - self.statements.alloc_with_id(|id| Statement::Put(f(PutStatementId(id)))), - ) - } pub fn alloc_labeled_statement( &mut self, f: impl FnOnce(LabeledStatementId) -> LabeledStatement, @@ -908,6 +899,7 @@ pub enum Constant { #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub enum Method { Get, + Put, Fires, Create, Symbolic(MethodSymbolic) @@ -1271,7 +1263,6 @@ pub enum Statement { Assert(AssertStatement), Goto(GotoStatement), New(NewStatement), - Put(PutStatement), Expression(ExpressionStatement), } @@ -1432,12 +1423,6 @@ impl Statement { _ => panic!("Unable to cast `Statement` to `NewStatement`"), } } - pub fn as_put(&self) -> &PutStatement { - match self { - Statement::Put(result) => result, - _ => panic!("Unable to cast `Statement` to `PutStatement`"), - } - } pub fn as_expression(&self) -> &ExpressionStatement { match self { Statement::Expression(result) => result, @@ -1457,7 +1442,6 @@ impl Statement { Statement::EndSynchronous(stmt) => stmt.next = Some(next), Statement::Assert(stmt) => stmt.next = Some(next), Statement::New(stmt) => stmt.next = Some(next), - Statement::Put(stmt) => stmt.next = Some(next), Statement::Expression(stmt) => stmt.next = Some(next), Statement::Return(_) | Statement::Break(_) @@ -1490,7 +1474,6 @@ impl SyntaxElement for Statement { Statement::Assert(stmt) => stmt.position(), Statement::Goto(stmt) => stmt.position(), Statement::New(stmt) => stmt.position(), - Statement::Put(stmt) => stmt.position(), Statement::Expression(stmt) => stmt.position(), } } @@ -1883,23 +1866,6 @@ impl SyntaxElement for NewStatement { } } -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct PutStatement { - pub this: PutStatementId, - // Phase 1: parser - pub position: InputPosition, - pub port: ExpressionId, - pub message: ExpressionId, - // Phase 2: linker - pub next: Option, -} - -impl SyntaxElement for PutStatement { - fn position(&self) -> InputPosition { - self.position - } -} - #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] pub struct ExpressionStatement { pub this: ExpressionStatementId, @@ -1925,7 +1891,6 @@ pub enum ExpressionParent { Return(ReturnStatementId), Assert(AssertStatementId), New(NewStatementId), - Put(PutStatementId, u32), // index of arg ExpressionStmt(ExpressionStatementId), Expression(ExpressionId, u32) // index within expression (e.g LHS or RHS of expression) } @@ -2040,6 +2005,13 @@ impl Expression { Expression::Variable(expr) => &expr.parent, } } + pub fn parent_expr_id(&self) -> Option { + if let ExpressionParent::Expression(id, _) = self.parent() { + Some(*id) + } else { + None + } + } pub fn set_parent(&mut self, parent: ExpressionParent) { match self { Expression::Assignment(expr) => expr.parent = parent, diff --git a/src/protocol/ast_printer.rs b/src/protocol/ast_printer.rs index 1949225c4f0c02d67d52867ddfe143f1ed0c8bd0..fa75fcb31a7dad2546a25ddac755c6a8b7a08c54 100644 --- a/src/protocol/ast_printer.rs +++ b/src/protocol/ast_printer.rs @@ -485,16 +485,6 @@ impl ASTWriter { self.kv(indent2).with_s_key("Next") .with_opt_disp_val(stmt.next.as_ref().map(|v| &v.index)); }, - Statement::Put(stmt) => { - self.kv(indent).with_id(PREFIX_PUT_STMT_ID, stmt.this.0.index) - .with_s_key("Put"); - self.kv(indent2).with_s_key("Port"); - self.write_expr(heap, stmt.port, indent3); - self.kv(indent2).with_s_key("Message"); - self.write_expr(heap, stmt.message, indent3); - self.kv(indent2).with_s_key("Next") - .with_opt_disp_val(stmt.next.as_ref().map(|v| &v.index)); - }, Statement::Expression(stmt) => { self.kv(indent).with_id(PREFIX_EXPR_STMT_ID, stmt.this.0.index) .with_s_key("ExpressionStatement"); @@ -628,6 +618,7 @@ impl ASTWriter { let method = self.kv(indent2).with_s_key("Method"); match &expr.method { Method::Get => { method.with_s_val("get"); }, + Method::Put => { method.with_s_val("put"); }, Method::Fires => { method.with_s_val("fires"); }, Method::Create => { method.with_s_val("create"); }, Method::Symbolic(symbolic) => { @@ -749,7 +740,6 @@ fn write_expression_parent(target: &mut String, parent: &ExpressionParent) { EP::Return(id) => format!("ReturnStmt({})", id.0.index), EP::Assert(id) => format!("AssertStmt({})", id.0.index), EP::New(id) => format!("NewStmt({})", id.0.index), - EP::Put(id, idx) => format!("PutStmt({}, {})", id.0.index, idx), EP::ExpressionStmt(id) => format!("ExprStmt({})", id.0.index), EP::Expression(id, idx) => format!("Expr({}, {})", id.index, idx) }; diff --git a/src/protocol/eval.rs b/src/protocol/eval.rs index e8e3c57d5d6e35c67d88aca09a9b2dfc626e7737..a16a05fa3183f31abe804a6d8ef8ba06dec471d7 100644 --- a/src/protocol/eval.rs +++ b/src/protocol/eval.rs @@ -1499,10 +1499,24 @@ impl Store { } Expression::Constant(expr) => Ok(Value::from_constant(&expr.value)), Expression::Call(expr) => match &expr.method { - Method::Create => { + Method::Get => { assert_eq!(1, expr.arguments.len()); - let length = self.eval(h, ctx, expr.arguments[0])?; - Ok(Value::create_message(length)) + let value = self.eval(h, ctx, expr.arguments[0])?; + match ctx.get(value.clone()) { + None => Err(EvalContinuation::BlockGet(value)), + Some(result) => Ok(result), + } + } + Method::Put => { + assert_eq!(2, expr.arguments.len()); + let port_value = self.eval(h, ctx, expr.arguments[0])?; + let msg_value = self.eval(h, ctx, expr.arguments[1])?; + if ctx.did_put(port_value.clone()) { + // Return bogus, replacing this at some point anyway + Ok(Value::Message(MessageValue(None))) + } else { + Err(EvalContinuation::Put(port_value, msg_value)) + } } Method::Fires => { assert_eq!(1, expr.arguments.len()); @@ -1512,13 +1526,10 @@ impl Store { Some(result) => Ok(result), } } - Method::Get => { + Method::Create => { assert_eq!(1, expr.arguments.len()); - let value = self.eval(h, ctx, expr.arguments[0])?; - match ctx.get(value.clone()) { - None => Err(EvalContinuation::BlockGet(value)), - Some(result) => Ok(result), - } + let length = self.eval(h, ctx, expr.arguments[0])?; + Ok(Value::create_message(length)) } Method::Symbolic(_symbol) => unimplemented!(), }, @@ -1696,15 +1707,6 @@ impl Prompt { _ => unreachable!("not a symbolic call expression") } } - Statement::Put(stmt) => { - // Evaluate port and message - let port = self.store.eval(h, ctx, stmt.port)?; - let message = self.store.eval(h, ctx, stmt.message)?; - // Continue to next statement - self.position = stmt.next; - // Signal the put upwards - Err(EvalContinuation::Put(port, message)) - } Statement::Expression(stmt) => { // Evaluate expression let _value = self.store.eval(h, ctx, stmt.expression)?; diff --git a/src/protocol/lexer.rs b/src/protocol/lexer.rs index e87532ab007d89541ba5dcc0394d4f9a14a95d89..2cf9f3b581cac741ce858461f489a9f2dce167d8 100644 --- a/src/protocol/lexer.rs +++ b/src/protocol/lexer.rs @@ -1414,7 +1414,7 @@ impl Lexer<'_> { if self.consume_namespaced_identifier_spilled().is_ok() && self.consume_whitespace(false).is_ok() && - self.maybe_consume_poly_args_spilled_without_pos_recovery().is_ok() && + self.maybe_consume_poly_args_spilled_without_pos_recovery() && self.consume_whitespace(false).is_ok() && self.source.next() == Some(b'(') { // Seems like we have a function call or an enum literal @@ -1432,6 +1432,9 @@ impl Lexer<'_> { if self.has_keyword(b"get") { self.consume_keyword(b"get")?; method = Method::Get; + } else if self.has_keyword(b"put") { + self.consume_keyword(b"put")?; + method = Method::Put; } else if self.has_keyword(b"fires") { self.consume_keyword(b"fires")?; method = Method::Fires; @@ -1553,8 +1556,6 @@ impl Lexer<'_> { self.consume_goto_statement(h)?.upcast() } else if self.has_keyword(b"new") { self.consume_new_statement(h)?.upcast() - } else if self.has_keyword(b"put") { - self.consume_put_statement(h)?.upcast() } else if self.has_label() { self.consume_labeled_statement(h)?.upcast() } else { @@ -1899,22 +1900,6 @@ impl Lexer<'_> { self.consume_string(b";")?; Ok(h.alloc_new_statement(|this| NewStatement { this, position, expression, next: None })) } - fn consume_put_statement(&mut self, h: &mut Heap) -> Result { - let position = self.source.pos(); - self.consume_keyword(b"put")?; - self.consume_whitespace(false)?; - self.consume_string(b"(")?; - let port = self.consume_expression(h)?; - self.consume_whitespace(false)?; - self.consume_string(b",")?; - self.consume_whitespace(false)?; - let message = self.consume_expression(h)?; - self.consume_whitespace(false)?; - self.consume_string(b")")?; - self.consume_whitespace(false)?; - self.consume_string(b";")?; - Ok(h.alloc_put_statement(|this| PutStatement { this, position, port, message, next: None })) - } fn consume_expression_statement( &mut self, h: &mut Heap, diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 9ea716fb06b139f56996a305106369224b576835..3352aa183902a1d0ca745f59a08e780944a27e04 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -294,4 +294,16 @@ impl EvalContext<'_> { }, } } + fn did_put(&mut self, port: Value) -> bool { + match self { + EvalContext::Nonsync(_) => unreachable!("did_put in nonsync context"), + EvalContext::Sync(context) => match port { + Value::Output(OutputValue(port)) => { + context.is_firing(port).unwrap_or(false) + }, + Value::Input(_) => unreachable!("did_put on input port"), + _ => unreachable!("did_put on non-port value") + } + } + } } diff --git a/src/protocol/parser/depth_visitor.rs b/src/protocol/parser/depth_visitor.rs index 0dfa83adc790cb5ca77def228b9686f717e8d5c0..da957edf9eb4611fd23a050d6d1f30287e99a8c9 100644 --- a/src/protocol/parser/depth_visitor.rs +++ b/src/protocol/parser/depth_visitor.rs @@ -118,9 +118,6 @@ pub(crate) trait Visitor: Sized { fn visit_new_statement(&mut self, h: &mut Heap, stmt: NewStatementId) -> VisitorResult { recursive_new_statement(self, h, stmt) } - fn visit_put_statement(&mut self, h: &mut Heap, stmt: PutStatementId) -> VisitorResult { - recursive_put_statement(self, h, stmt) - } fn visit_expression_statement( &mut self, h: &mut Heap, @@ -320,7 +317,6 @@ fn recursive_statement(this: &mut T, h: &mut Heap, stmt: StatementId Statement::Assert(stmt) => this.visit_assert_statement(h, stmt.this), Statement::Goto(stmt) => this.visit_goto_statement(h, stmt.this), Statement::New(stmt) => this.visit_new_statement(h, stmt.this), - Statement::Put(stmt) => this.visit_put_statement(h, stmt.this), Statement::Expression(stmt) => this.visit_expression_statement(h, stmt.this), Statement::EndSynchronous(stmt) => this.visit_end_synchronous_statement(h, stmt.this), Statement::EndWhile(stmt) => this.visit_end_while_statement(h, stmt.this), @@ -424,15 +420,6 @@ fn recursive_new_statement( recursive_call_expression_as_expression(this, h, h[stmt].expression) } -fn recursive_put_statement( - this: &mut T, - h: &mut Heap, - stmt: PutStatementId, -) -> VisitorResult { - this.visit_expression(h, h[stmt].port)?; - this.visit_expression(h, h[stmt].message) -} - fn recursive_expression_statement( this: &mut T, h: &mut Heap, @@ -969,10 +956,6 @@ impl Visitor for LinkStatements { self.prev = Some(UniqueStatementId(stmt.upcast())); Ok(()) } - fn visit_put_statement(&mut self, _h: &mut Heap, stmt: PutStatementId) -> VisitorResult { - self.prev = Some(UniqueStatementId(stmt.upcast())); - Ok(()) - } fn visit_expression_statement( &mut self, _h: &mut Heap, diff --git a/src/protocol/parser/type_resolver.rs b/src/protocol/parser/type_resolver.rs index e4b71644faa72e33c00a9b956dde1093e3153dc8..8b7ea100689138c4099b292f21e0b39bca2e1ed8 100644 --- a/src/protocol/parser/type_resolver.rs +++ b/src/protocol/parser/type_resolver.rs @@ -1,3 +1,10 @@ +/// type_resolver.rs +/// +/// Performs type inference and type checking +/// +/// TODO: Needs an optimization pass +/// TODO: Needs a cleanup pass + use std::collections::{HashMap, HashSet, VecDeque}; use crate::protocol::ast::*; @@ -12,7 +19,9 @@ use super::visitor::{ VisitorResult }; use std::collections::hash_map::Entry; +use crate::protocol::parser::type_resolver::InferenceTypePart::IntegerLike; +const MESSAGE_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::Message ]; const BOOL_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::Bool ]; const NUMBERLIKE_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::NumberLike ]; const INTEGERLIKE_TEMPLATE: [InferenceTypePart; 1] = [ InferenceTypePart::IntegerLike ]; @@ -181,6 +190,13 @@ impl InferenceType { Self{ has_marker, is_done, parts } } + fn replace_subtree(&mut self, start_idx: usize, with: &[InferenceTypePart]) { + let end_idx = Self::find_subtree_end_idx(&self.parts, start_idx); + debug_assert_eq!(with.len(), Self::find_subtree_end_idx(with, 0)); + self.parts.splice(start_idx..end_idx, with.iter().cloned()); + self.recompute_is_done(); + } + // TODO: @performance, might all be done inline in the type inference methods fn recompute_is_done(&mut self) { self.is_done = self.parts.iter().all(|v| v.is_concrete()); @@ -229,10 +245,15 @@ impl InferenceType { } } + /// Returns an iterator over all markers and the partial type tree that + /// follows those markers. fn marker_iter(&self) -> InferenceTypeMarkerIter { InferenceTypeMarkerIter::new(&self.parts) } + /// Attempts to find a marker with a specific value appearing at or after + /// the specified index. If found then the partial type tree's bounding + /// indices that follow that marker are returned. fn find_subtree_idx_for_marker(&self, marker: usize, mut idx: usize) -> Option<(usize, usize)> { // Seek ahead to find a marker let marker = InferenceTypePart::Marker(marker); @@ -700,9 +721,12 @@ pub(crate) struct TypeResolvingVisitor { polyvars: Vec, // Mapping from parser type to inferred type. We attempt to continue to // specify these types until we're stuck or we've fully determined the type. - infer_types: HashMap, - expr_types: HashMap, - extra_data: HashMap, + var_types: HashMap, // types of variables + expr_types: HashMap, // types of expressions + extra_data: HashMap, // data for function call inference + + // Keeping track of which expressions need to be reinferred because the + // expressions they're linked to made progression on an associated type expr_queued: HashSet, } @@ -715,6 +739,11 @@ struct ExtraData { returned: InferenceType, } +struct VarData { + var_type: InferenceType, + used_at: Vec, +} + impl TypeResolvingVisitor { pub(crate) fn new() -> Self { TypeResolvingVisitor{ @@ -722,7 +751,7 @@ impl TypeResolvingVisitor { stmt_buffer: Vec::with_capacity(STMT_BUFFER_INIT_CAPACITY), expr_buffer: Vec::with_capacity(EXPR_BUFFER_INIT_CAPACITY), polyvars: Vec::new(), - infer_types: HashMap::new(), + var_types: HashMap::new(), expr_types: HashMap::new(), extra_data: HashMap::new(), expr_queued: HashSet::new(), @@ -734,7 +763,7 @@ impl TypeResolvingVisitor { self.stmt_buffer.clear(); self.expr_buffer.clear(); self.polyvars.clear(); - self.infer_types.clear(); + self.var_types.clear(); self.expr_types.clear(); } } @@ -751,9 +780,9 @@ impl Visitor2 for TypeResolvingVisitor { for param_id in comp_def.parameters.clone() { let param = &ctx.heap[param_id]; - let infer_type = self.determine_inference_type_from_parser_type(ctx, param.parser_type, true); - debug_assert!(infer_type.is_done, "expected component arguments to be concrete types"); - self.infer_types.insert(param_id.upcast(), infer_type); + let var_type = self.determine_inference_type_from_parser_type(ctx, param.parser_type, true); + debug_assert!(var_type.is_done, "expected component arguments to be concrete types"); + self.var_types.insert(param_id.upcast(), VarData{ var_type, used_at: Vec::new() }); } let body_stmt_id = ctx.heap[id].body; @@ -769,9 +798,9 @@ impl Visitor2 for TypeResolvingVisitor { for param_id in func_def.parameters.clone() { let param = &ctx.heap[param_id]; - let infer_type = self.determine_inference_type_from_parser_type(ctx, param.parser_type, true); - debug_assert!(infer_type.is_done, "expected function arguments to be concrete types"); - self.infer_types.insert(param_id.upcast(), infer_type); + let var_type = self.determine_inference_type_from_parser_type(ctx, param.parser_type, true); + debug_assert!(var_type.is_done, "expected function arguments to be concrete types"); + self.var_types.insert(param_id.upcast(), VarData{ var_type, used_at: Vec::new() }); } let body_stmt_id = ctx.heap[id].body; @@ -795,8 +824,8 @@ impl Visitor2 for TypeResolvingVisitor { let memory_stmt = &ctx.heap[id]; let local = &ctx.heap[memory_stmt.variable]; - let infer_type = self.determine_inference_type_from_parser_type(ctx, local.parser_type, true); - self.infer_types.insert(memory_stmt.variable.upcast(), infer_type); + let var_type = self.determine_inference_type_from_parser_type(ctx, local.parser_type, true); + self.var_types.insert(memory_stmt.variable.upcast(), VarData{ var_type, used_at: Vec::new() }); let expr_id = memory_stmt.initial; self.visit_expr(ctx, expr_id)?; @@ -808,12 +837,12 @@ impl Visitor2 for TypeResolvingVisitor { let channel_stmt = &ctx.heap[id]; let from_local = &ctx.heap[channel_stmt.from]; - let from_infer_type = self.determine_inference_type_from_parser_type(ctx, from_local.parser_type, true); - self.infer_types.insert(from_local.this.upcast(), from_infer_type); + let from_var_type = self.determine_inference_type_from_parser_type(ctx, from_local.parser_type, true); + self.var_types.insert(from_local.this.upcast(), VarData{ var_type: from_var_type, used_at: Vec::new() }); let to_local = &ctx.heap[channel_stmt.to]; - let to_infer_type = self.determine_inference_type_from_parser_type(ctx, to_local.parser_type, true); - self.infer_types.insert(to_local.this.upcast(), to_infer_type); + let to_var_type = self.determine_inference_type_from_parser_type(ctx, to_local.parser_type, true); + self.var_types.insert(to_local.this.upcast(), VarData{ var_type: to_var_type, used_at: Vec::new() }); Ok(()) } @@ -878,19 +907,6 @@ impl Visitor2 for TypeResolvingVisitor { self.visit_call_expr(ctx, call_expr_id) } - fn visit_put_stmt(&mut self, ctx: &mut Ctx, id: PutStatementId) -> VisitorResult { - let put_stmt = &ctx.heap[id]; - - let port_expr_id = put_stmt.port; - let msg_expr_id = put_stmt.message; - // TODO: What what? - - self.visit_expr(ctx, port_expr_id)?; - self.visit_expr(ctx, msg_expr_id)?; - - Ok(()) - } - fn visit_expr_stmt(&mut self, ctx: &mut Ctx, id: ExpressionStatementId) -> VisitorResult { let expr_stmt = &ctx.heap[id]; let subexpr_id = expr_stmt.expression; @@ -952,11 +968,72 @@ impl Visitor2 for TypeResolvingVisitor { let unary_expr = &ctx.heap[id]; let arg_expr_id = unary_expr.expression; - self.visit_expr(ctx, arg_expr_id); + self.visit_expr(ctx, arg_expr_id)?; self.progress_unary_expr(ctx, id) } + fn visit_indexing_expr(&mut self, ctx: &mut Ctx, id: IndexingExpressionId) -> VisitorResult { + let upcast_id = id.upcast(); + self.insert_initial_expr_inference_type(ctx, upcast_id)?; + + let indexing_expr = &ctx.heap[id]; + let subject_expr_id = indexing_expr.subject; + let index_expr_id = indexing_expr.index; + + self.visit_expr(ctx, subject_expr_id)?; + self.visit_expr(ctx, index_expr_id)?; + + self.progress_indexing_expr(ctx, id) + } + + fn visit_slicing_expr(&mut self, ctx: &mut Ctx, id: SlicingExpressionId) -> VisitorResult { + let upcast_id = id.upcast(); + self.insert_initial_expr_inference_type(ctx, upcast_id)?; + + let slicing_expr = &ctx.heap[id]; + let subject_expr_id = slicing_expr.subject; + let from_expr_id = slicing_expr.from_index; + let to_expr_id = slicing_expr.to_index; + + self.visit_expr(ctx, subject_expr_id)?; + self.visit_expr(ctx, from_expr_id)?; + self.visit_expr(ctx, to_expr_id)?; + + self.progress_slicing_expr(ctx, id) + } + + fn visit_select_expr(&mut self, ctx: &mut Ctx, id: SelectExpressionId) -> VisitorResult { + let upcast_id = id.upcast(); + self.insert_initial_expr_inference_type(ctx, upcast_id)?; + + let select_expr = &ctx.heap[id]; + let subject_expr_id = select_expr.subject; + + self.visit_expr(ctx, subject_expr_id)?; + + self.progress_select_expr(ctx, id) + } + + fn visit_array_expr(&mut self, ctx: &mut Ctx, id: ArrayExpressionId) -> VisitorResult { + let upcast_id = id.upcast(); + self.insert_initial_expr_inference_type(ctx, upcast_id)?; + + let array_expr = &ctx.heap[id]; + // TODO: @performance + for element_id in array_expr.elements.clone().into_iter() { + self.visit_expr(ctx, element_id)?; + } + + self.progress_array_expr(ctx, id) + } + + fn visit_constant_expr(&mut self, ctx: &mut Ctx, id: ConstantExpressionId) -> VisitorResult { + let upcast_id = id.upcast(); + self.insert_initial_expr_inference_type(ctx, upcast_id)?; + self.progress_constant_expr(ctx, id) + } + fn visit_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> VisitorResult { let upcast_id = id.upcast(); self.insert_initial_expr_inference_type(ctx, upcast_id)?; @@ -970,6 +1047,18 @@ impl Visitor2 for TypeResolvingVisitor { self.progress_call_expr(ctx, id) } + + fn visit_variable_expr(&mut self, ctx: &mut Ctx, id: VariableExpressionId) -> VisitorResult { + let upcast_id = id.upcast(); + self.insert_initial_expr_inference_type(ctx, upcast_id)?; + + let var_expr = &ctx.heap[id]; + debug_assert!(var_expr.declaration.is_some()); + let var_data = self.var_types.get_mut(var_expr.declaration.as_ref().unwrap()).unwrap(); + var_data.used_at.push(upcast_id); + + self.progress_variable_expr(ctx, id) + } } macro_rules! debug_assert_expr_ids_unique_and_known { @@ -1209,6 +1298,79 @@ impl TypeResolvingVisitor { Ok(()) } + fn progress_select_expr(&mut self, ctx: &mut Ctx, id: SelectExpressionId) -> Result<(), ParseError2> { + let upcast_id = id.upcast(); + let expr = &ctx.heap[id]; + let subject_id = expr.subject; + + let (progress_subject, progress_expr) = match &expr.field { + Field::Length => { + let progress_subject = self.apply_forced_constraint(ctx, subject_id, &ARRAYLIKE_TEMPLATE)?; + let progress_expr = self.apply_forced_constraint(ctx, upcast_id, &INTEGERLIKE_TEMPLATE)?; + (progress_subject, progress_expr) + }, + Field::Symbolic(_field) => { + todo!("implement select expr for symbolic fields"); + } + }; + + if progress_subject { self.queue_expr(subject_id); } + if progress_expr { self.queue_expr_parent(ctx, upcast_id); } + + Ok(()) + } + + fn progress_array_expr(&mut self, ctx: &mut Ctx, id: ArrayExpressionId) -> Result<(), ParseError2> { + let upcast_id = id.upcast(); + let expr = &ctx.heap[id]; + let expr_elements = expr.elements.clone(); // TODO: @performance + + // All elements should have an equal type + let progress = self.apply_equal_n_constraint(ctx, upcast_id, &expr_elements)?; + let mut any_progress = false; + for (progress_arg, arg_id) in progress.iter().zip(expr_elements.iter()) { + if *progress_arg { + any_progress = true; + self.queue_expr(*arg_id); + } + } + + // And the output should be an array of the element types + let mut expr_progress = self.apply_forced_constraint(ctx, upcast_id, &ARRAY_TEMPLATE)?; + if !expr_elements.is_empty() { + let first_arg_id = expr_elements[0]; + let (inner_expr_progress, arg_progress) = self.apply_equal2_constraint( + ctx, upcast_id, upcast_id, 1, first_arg_id, 0 + )?; + + expr_progress = expr_progress || inner_expr_progress; + + // Note that if the array type progressed the type of the arguments, + // then we should enqueue this progression function again + if arg_progress { self.queue_expr(upcast_id); } + } + + if expr_progress { self.queue_expr_parent(ctx, upcast_id); } + + Ok(()) + } + + fn progress_constant_expr(&mut self, ctx: &mut Ctx, id: ConstantExpressionId) -> Result<(), ParseError2> { + let upcast_id = id.upcast(); + let expr = &ctx.heap[id]; + let template = match &expr.value { + Constant::Null => &MESSAGE_TEMPLATE, + Constant::Integer(_) => &INTEGERLIKE_TEMPLATE, + Constant::True | Constant::False => &BOOL_TEMPLATE, + Constant::Character(_) => todo!("character literals") + }; + + let progress = self.apply_forced_constraint(ctx, upcast_id, template)?; + if progress { self.queue_expr_parent(ctx, upcast_id); } + + Ok(()) + } + // TODO: @cleanup, see how this can be cleaned up once I implement // polymorphic struct/enum/union literals. These likely follow the same // pattern as here. @@ -1247,7 +1409,7 @@ impl TypeResolvingVisitor { } if progress_arg { // Progressed argument expression - self.queue_expr(arg_id); + self.expr_queued.insert(arg_id); } } @@ -1273,7 +1435,9 @@ impl TypeResolvingVisitor { } } if progress_expr { - self.queue_expr_parent(ctx, upcast_id); + if let Some(parent_id) = ctx.heap[upcast_id].parent_expr_id() { + self.expr_queued.insert(parent_id); + } } // If we had an error in the polymorphic variable's inference, then we @@ -1292,12 +1456,12 @@ impl TypeResolvingVisitor { // For each polymorphic argument: first extend the signature type, // then reapply the equal2 constraint to the expressions let poly_type = &extra.poly_vars[poly_idx]; - for (arg_idx, arg_type) in extra.embedded.iter_mut().enumerate() { + for (arg_idx, sig_type) in extra.embedded.iter_mut().enumerate() { let mut seek_idx = 0; let mut modified_sig = false; - while let Some((start_idx, end_idx)) = arg_type.find_subtree_idx_for_marker(poly_idx, seek_idx) { + while let Some((start_idx, end_idx)) = sig_type.find_subtree_idx_for_marker(poly_idx, seek_idx) { let modified_at_marker = Self::apply_forced_constraint_types( - arg_type, start_idx, &poly_type.parts, 0 + sig_type, start_idx, &poly_type.parts, 0 ).unwrap(); modified_sig = modified_sig || modified_at_marker; seek_idx = end_idx; @@ -1308,10 +1472,81 @@ impl TypeResolvingVisitor { // Part of signature was modified, so update expression used as // argument as well let arg_expr_id = expr.arguments[arg_idx]; - let arg_type = self.expr_types.get_mut(arg_expr_id).unwrap(); - Self::apply_equal2_constraint_types(ctx, arg_expr_id, ) + let arg_type: *mut _ = self.expr_types.get_mut(&arg_expr_id).unwrap(); + let (progress_arg, _) = Self::apply_equal2_constraint_types( + ctx, arg_expr_id, arg_type, 0, sig_type, 0 + ).expect("no inference error at argument type"); + if progress_arg { self.expr_queued.insert(arg_expr_id); } + } + + // Again: do the same for the return type + let sig_type = &mut extra.returned; + let mut seek_idx = 0; + let mut modified_sig = false; + while let Some((start_idx, end_idx)) = sig_type.find_subtree_idx_for_marker(poly_idx, seek_idx) { + let modified_at_marker = Self::apply_forced_constraint_types( + sig_type, start_idx, &poly_type.parts, 0 + ).unwrap(); + modified_sig = modified_sig || modified_at_marker; + seek_idx = end_idx; + } + + if modified_sig { + let ret_type = self.expr_types.get_mut(&upcast_id).unwrap(); + let (progress_ret, _) = Self::apply_equal2_constraint_types( + ctx, upcast_id, ret_type, 0, sig_type, 0 + ).expect("no inference error at return type"); + if progress_ret { + if let Some(parent_id) = ctx.heap[upcast_id].parent_expr_id() { + self.expr_queued.insert(parent_id); + } + } + } + } + + Ok(()) + } + + fn progress_variable_expr(&mut self, ctx: &mut Ctx, id: VariableExpressionId) -> Result<(), ParseError2> { + let upcast_id = id.upcast(); + let var_expr = &ctx.heap[id]; + let var_id = var_expr.declaration.unwrap(); + + // Retrieve shared variable type and expression type and apply inference + let var_data = self.var_types.get_mut(&var_id).unwrap(); + let expr_type = self.expr_types.get_mut(&upcast_id).unwrap(); + + let infer_res = unsafe{ InferenceType::infer_subtrees_for_both_types( + &mut var_data.var_type as *mut _, 0, expr_type, 0 + ) }; + if infer_res == DualInferenceResult::Incompatible { + let var_decl = &ctx.heap[var_id]; + return Err(ParseError2::new_error( + &ctx.module.source, var_decl.position(), + &format!( + "Conflicting types for this variable, previously assigned the type '{}'", + var_data.var_type.display_name(&ctx.heap) + ) + ).with_postfixed_info( + &ctx.module.source, var_expr.position, + &format!( + "But inferred to have incompatible type '{}' here", + expr_type.display_name(&ctx.heap) + ) + )) + } + + let progress_var = infer_res.modified_lhs(); + let progress_expr = infer_res.modified_rhs(); + + if progress_var { + for other_expr in var_data.used_at.iter() { + if *other_expr != upcast_id { + self.expr_queued.insert(*other_expr); + } } } + if progress_expr { self.queue_expr_parent(ctx, upcast_id); } Ok(()) } @@ -1442,8 +1677,9 @@ impl TypeResolvingVisitor { if args_res.modified_lhs() { unsafe { - (*expr_type).parts.drain(start_idx..); - (*expr_type).parts.extend_from_slice(&((*arg2_type).parts[start_idx..])); + let end_idx = InferenceType::find_subtree_end_idx(&(*arg2_type).parts, start_idx); + let subtree = &((*arg2_type).parts[start_idx..end_idx]); + (*expr_type).replace_subtree(start_idx, subtree); } progress_expr = true; progress_arg1 = true; @@ -1452,6 +1688,66 @@ impl TypeResolvingVisitor { Ok((progress_expr, progress_arg1, progress_arg2)) } + // TODO: @optimize Since we only deal with a single type this might be done + // a lot more efficiently, methinks (disregarding the allocations here) + fn apply_equal_n_constraint( + &mut self, ctx: &Ctx, expr_id: ExpressionId, args: &[ExpressionId], + ) -> Result, ParseError2> { + // Early exit + match args.len() { + 0 => return Ok(vec!()), // nothing to progress + 1 => return Ok(vec![false]), // only one type, so nothing to infer + _ => {} + } + + let mut progress = Vec::new(); + progress.resize(args.len(), false); + + // Do pairwise inference, keep track of the last entry we made progress + // on. Once done we need to update everything to the most-inferred type. + let mut arg_iter = args.iter(); + let mut last_arg_id = *arg_iter.next().unwrap(); + let mut last_lhs_progressed = 0; + let mut lhs_arg_idx = 0; + + while let Some(next_arg_id) = arg_iter.next() { + let arg1_type: *mut _ = self.expr_types.get_mut(&last_arg_id).unwrap(); + let arg2_type: *mut _ = self.expr_types.get_mut(next_arg_id).unwrap(); + + let res = unsafe { + InferenceType::infer_subtrees_for_both_types(arg1_type, 0, arg2_type, 0) + }; + + if res == DualInferenceResult::Incompatible { + return Err(self.construct_arg_type_error(ctx, expr_id, last_arg_id, *next_arg_id)); + } + + if res.modified_lhs() { + // We re-inferred something on the left hand side, so everything + // up until now should be re-inferred. + progress[lhs_arg_idx] = true; + last_lhs_progressed = lhs_arg_idx; + } + progress[lhs_arg_idx + 1] = res.modified_rhs(); + + last_arg_id = *next_arg_id; + lhs_arg_idx += 1; + } + + // Re-infer everything. Note that we do not need to re-infer the type + // exactly at `last_lhs_progressed`, but only everything up to it. + let last_type: *mut _ = self.expr_types.get_mut(args.last().unwrap()).unwrap(); + for arg_idx in 0..last_lhs_progressed { + let arg_type: *mut _ = self.expr_types.get_mut(&args[arg_idx]).unwrap(); + unsafe{ + (*arg_type).replace_subtree(0, &(*last_type).parts); + } + progress[arg_idx] = true; + } + + Ok(progress) + } + /// Determines the `InferenceType` for the expression based on the /// expression parent. Note that if the parent is another expression, we do /// not take special action, instead we let parent expressions fix the type @@ -1489,16 +1785,6 @@ impl TypeResolvingVisitor { // Must be a component call, which we assign a "Void" return // type InferenceType::new(false, true, vec![ITP::Void]), - EP::Put(_, 0) => - // TODO: Change put to be a builtin function - // port of "put" call - InferenceType::new(false, false, vec![ITP::Output, ITP::Unknown]), - EP::Put(_, 1) => - // TODO: Change put to be a builtin function - // message of "put" call - InferenceType::new(false, true, vec![ITP::Message]), - EP::Put(_, _) => - unreachable!() }; match self.expr_types.entry(expr_id) { @@ -1568,6 +1854,16 @@ impl TypeResolvingVisitor { InferenceType::new(true, false, vec![ITP::Marker(0), ITP::Unknown]) ) }, + Method::Put => { + // void Put(output port, T msg) + ( + vec![ + InferenceType::new(true, false, vec![ITP::Output, ITP::Marker(0), ITP::Unknown]), + InferenceType::new(true, false, vec![ITP::Marker(0), ITP::Unknown]) + ], + InferenceType::new(false, true, vec![ITP::Void]) + ) + } Method::Symbolic(symbolic) => { let definition = &ctx.heap[symbolic.definition.unwrap()]; @@ -1663,10 +1959,11 @@ impl TypeResolvingVisitor { }, PTV::Symbolic(symbolic) => { debug_assert!(symbolic.variant.is_some(), "symbolic variant not yet determined"); - match symbolic.variant.unwrap() { + match symbolic.variant.as_ref().unwrap() { SymbolicParserTypeVariant::PolyArg(_, arg_idx) => { // Retrieve concrete type of argument and add it to // the inference type. + let arg_idx = *arg_idx; debug_assert!(symbolic.poly_args.is_empty()); // TODO: @hkt if parser_type_in_body { @@ -1683,7 +1980,7 @@ impl TypeResolvingVisitor { SymbolicParserTypeVariant::Definition(definition_id) => { // TODO: @cleanup if cfg!(debug_assertions) { - let definition = &ctx.heap[definition_id]; + let definition = &ctx.heap[*definition_id]; debug_assert!(definition.is_struct() || definition.is_enum()); // TODO: @function_ptrs let num_poly = match definition { Definition::Struct(v) => v.poly_vars.len(), @@ -1693,7 +1990,7 @@ impl TypeResolvingVisitor { debug_assert_eq!(symbolic.poly_args.len(), num_poly); } - infer_type.push(ITP::Instance(definition_id, symbolic.poly_args.len())); + infer_type.push(ITP::Instance(*definition_id, symbolic.poly_args.len())); let mut poly_arg_idx = symbolic.poly_args.len(); while poly_arg_idx > 0 { poly_arg_idx -= 1; @@ -1823,6 +2120,7 @@ impl TypeResolvingVisitor { Method::Create => unreachable!(), Method::Fires => (String::from('T'), String::from("fires")), Method::Get => (String::from('T'), String::from("get")), + Method::Put => (String::from('T'), String::from("put")), Method::Symbolic(symbolic) => { let definition = &ctx.heap[symbolic.definition.unwrap()]; let poly_var = match definition { diff --git a/src/protocol/parser/type_table.rs b/src/protocol/parser/type_table.rs index 6a1acd0e24e14d773634aa79191664fb8b0c7433..2ce2d93f242b7245f5a6f75e898407ab370e4600 100644 --- a/src/protocol/parser/type_table.rs +++ b/src/protocol/parser/type_table.rs @@ -181,8 +181,8 @@ pub struct StructField { } pub struct FunctionType { - return_type: ParserTypeId, - arguments: Vec + pub return_type: ParserTypeId, + pub arguments: Vec } pub struct ComponentType { @@ -641,7 +641,8 @@ impl TypeTable { // Construct polymorphic arguments let mut poly_args = self.create_initial_poly_args(&definition.poly_vars); - self.check_and_resolve_embedded_type_and_modify_poly_args(ctx, definition_id, &mut poly_args, root_id, definition.return_type)?; + let return_type_id = definition.return_type; + self.check_and_resolve_embedded_type_and_modify_poly_args(ctx, definition_id, &mut poly_args, root_id, return_type_id)?; for argument in &arguments { self.check_and_resolve_embedded_type_and_modify_poly_args(ctx, definition_id, &mut poly_args, root_id, argument.parser_type)?; } diff --git a/src/protocol/parser/visitor.rs b/src/protocol/parser/visitor.rs index a4794193bcf81b3fdd9a2a30e9f94a4e9d956c33..8332e8e681e073aa53dd5525d7f554c74ccf1d70 100644 --- a/src/protocol/parser/visitor.rs +++ b/src/protocol/parser/visitor.rs @@ -134,10 +134,6 @@ pub(crate) trait Visitor2 { let this = stmt.this; self.visit_new_stmt(ctx, this) }, - Statement::Put(stmt) => { - let this = stmt.this; - self.visit_put_stmt(ctx, this) - }, Statement::Expression(stmt) => { let this = stmt.this; self.visit_expr_stmt(ctx, this) @@ -173,7 +169,6 @@ pub(crate) trait Visitor2 { fn visit_assert_stmt(&mut self, _ctx: &mut Ctx, _id: AssertStatementId) -> VisitorResult { Ok(()) } fn visit_goto_stmt(&mut self, _ctx: &mut Ctx, _id: GotoStatementId) -> VisitorResult { Ok(()) } fn visit_new_stmt(&mut self, _ctx: &mut Ctx, _id: NewStatementId) -> VisitorResult { Ok(()) } - fn visit_put_stmt(&mut self, _ctx: &mut Ctx, _id: PutStatementId) -> VisitorResult { Ok(()) } fn visit_expr_stmt(&mut self, _ctx: &mut Ctx, _id: ExpressionStatementId) -> VisitorResult { Ok(()) } // Expressions diff --git a/src/protocol/parser/visitor_linker.rs b/src/protocol/parser/visitor_linker.rs index b2d93e0200c3b3298de4a287ecba2cf5af5e1ced..c013cab65a594ac86aa9d90cb7eecc120dd0691b 100644 --- a/src/protocol/parser/visitor_linker.rs +++ b/src/protocol/parser/visitor_linker.rs @@ -561,32 +561,6 @@ impl Visitor2 for ValidityAndLinkerVisitor { Ok(()) } - fn visit_put_stmt(&mut self, ctx: &mut Ctx, id: PutStatementId) -> VisitorResult { - // TODO: Make `put` an expression. Perhaps silly, but much easier to - // perform typechecking - if self.performing_breadth_pass { - let put_stmt = &ctx.heap[id]; - if self.in_sync.is_none() { - return Err(ParseError2::new_error( - &ctx.module.source, put_stmt.position, "Put must be called in a synchronous block" - )); - } - } else { - let put_stmt = &ctx.heap[id]; - let port = put_stmt.port; - let message = put_stmt.message; - - debug_assert_eq!(self.expr_parent, ExpressionParent::None); - self.expr_parent = ExpressionParent::Put(id, 0); - self.visit_expr(ctx, port)?; - self.expr_parent = ExpressionParent::Put(id, 1); - self.visit_expr(ctx, message)?; - self.expr_parent = ExpressionParent::None; - } - - Ok(()) - } - fn visit_expr_stmt(&mut self, ctx: &mut Ctx, id: ExpressionStatementId) -> VisitorResult { if !self.performing_breadth_pass { let expr_id = ctx.heap[id].expression; @@ -779,15 +753,16 @@ impl Visitor2 for ValidityAndLinkerVisitor { debug_assert!(!self.performing_breadth_pass); let call_expr = &mut ctx.heap[id]; + let num_expr_args = call_expr.arguments.len(); // Resolve the method to the appropriate definition and check the // legality of the particular method call. // TODO: @cleanup Unify in some kind of signature call, see similar // cleanup comments with this `match` format. - let num_args; + let num_definition_args; match &mut call_expr.method { Method::Create => { - num_args = 1; + num_definition_args = 1; }, Method::Fires => { if !self.def_type.is_primitive() { @@ -796,7 +771,13 @@ impl Visitor2 for ValidityAndLinkerVisitor { "A call to 'fires' may only occur in primitive component definitions" )); } - num_args = 1; + if self.in_sync.is_none() { + return Err(ParseError2::new_error( + &ctx.module.source, call_expr.position, + "A call to 'fires' may only occur inside synchronous blocks" + )); + } + num_definition_args = 1; }, Method::Get => { if !self.def_type.is_primitive() { @@ -805,8 +786,29 @@ impl Visitor2 for ValidityAndLinkerVisitor { "A call to 'get' may only occur in primitive component definitions" )); } - num_args = 1; + if self.in_sync.is_none() { + return Err(ParseError2::new_error( + &ctx.module.source, call_expr.position, + "A call to 'get' may only occur inside synchronous blocks" + )); + } + num_definition_args = 1; }, + Method::Put => { + if !self.def_type.is_primitive() { + return Err(ParseError2::new_error( + &ctx.module.source, call_expr.position, + "A call to 'put' may only occur in primitive component definitions" + )); + } + if self.in_sync.is_none() { + return Err(ParseError2::new_error( + &ctx.module.source, call_expr.position, + "A call to 'put' may only occur inside synchronous blocks" + )); + } + num_definition_args = 2; + } Method::Symbolic(symbolic) => { // Find symbolic method let found_symbol = self.find_symbol_of_type( @@ -830,9 +832,9 @@ impl Visitor2 for ValidityAndLinkerVisitor { }; symbolic.definition = Some(definition_id); - match ctx.types.get_base_definition(&definition_id).unwrap() { - Definition::Function(definition) => { - num_args = definition.parameters.len(); + match &ctx.types.get_base_definition(&definition_id).unwrap().definition { + DefinedTypeVariant::Function(definition) => { + num_definition_args = definition.arguments.len(); }, _ => unreachable!(), } @@ -842,18 +844,18 @@ impl Visitor2 for ValidityAndLinkerVisitor { // Check the poly args and the number of variables in the call // expression self.visit_call_poly_args(ctx, id)?; - if call_expr.arguments.len() != num_args { + let call_expr = &mut ctx.heap[id]; + if num_expr_args != num_definition_args { return Err(ParseError2::new_error( &ctx.module.source, call_expr.position, &format!( "This call expects {} arguments, but {} were provided", - num_args, call_expr.arguments.len() + num_definition_args, num_expr_args ) )); } // Recurse into all of the arguments and set the expression's parent - let call_expr = &mut ctx.heap[id]; let upcast_id = id.upcast(); let old_num_exprs = self.expression_buffer.len(); @@ -1478,6 +1480,9 @@ impl ValidityAndLinkerVisitor { Method::Get => { 1 }, + Method::Put => { + 1 + } Method::Symbolic(symbolic) => { let definition = &ctx.heap[symbolic.definition.unwrap()]; if let Definition::Function(definition) = definition {