diff --git a/src/protocol/parser/pass_typing.rs b/src/protocol/parser/pass_typing.rs index 6f40f02bbff94ca6d26ba694fb95e1c6d2e0958b..aa9bdd04553257001f9173bea47edcf384bc7d10 100644 --- a/src/protocol/parser/pass_typing.rs +++ b/src/protocol/parser/pass_typing.rs @@ -810,6 +810,7 @@ enum SingleInferenceResult { type InferNodeIndex = usize; type PolyDataIndex = usize; +type VarDataIndex = usize; enum DefinitionType{ Component(ComponentDefinitionId), @@ -885,6 +886,9 @@ impl InferenceRule { union_cast_method_impl!(as_literal_union, InferenceRuleLiteralUnion, InferenceRule::LiteralUnion); union_cast_method_impl!(as_literal_array, InferenceRuleLiteralArray, InferenceRule::LiteralArray); union_cast_method_impl!(as_literal_tuple, InferenceRuleLiteralTuple, InferenceRule::LiteralTuple); + union_cast_method_impl!(as_cast_expr, InferenceRuleCastExpr, InferenceRule::CastExpr); + union_cast_method_impl!(as_call_expr, InferenceRuleCallExpr, InferenceRule::CallExpr); + union_cast_method_impl!(as_variable_expr, InferenceRuleVariableExpr, InferenceRule::VariableExpr); } struct InferenceRuleTemplate { @@ -1002,13 +1006,7 @@ struct InferenceRuleCallExpr { /// Data associated with a variable expression: an expression that reads the /// value from a variable. struct InferenceRuleVariableExpr { - variable_index: InferNodeIndex, -} - -/// Data associated with a variable: keeping track of all the variable -/// expressions in which it is used. -struct InferenceRuleVariable { - used_at_indices: Vec, + var_data_index: VarDataIndex, // shared variable information } /// This particular visitor will recurse depth-first into the AST and ensures @@ -1027,11 +1025,13 @@ pub(crate) struct PassTyping { bool_buffer: ScopedBuffer, index_buffer: ScopedBuffer, poly_progress_buffer: ScopedBuffer, + temp_type_parts: Vec, // Mapping from parser type to inferred type. We attempt to continue to // specify these types until we're stuck or we've fully determined the type. - var_types: HashMap, // types of variables + var_types: HashMap, // types of variables infer_nodes: Vec, // will be transferred to type table at end poly_data: Vec, // data for polymorph inference + var_data: Vec, // Keeping track of which expressions need to be reinferred because the // expressions they're linked to made progression on an associated type expr_queued: DequeSet, @@ -1040,6 +1040,7 @@ pub(crate) struct PassTyping { /// Generic struct that is used to store inferred types associated with /// polymorphic types. struct PolyData { + first_rule_application: bool, definition_id: DefinitionId, // the definition, only used for user feedback /// Inferred types of the polymorphic variables as they are written down /// at the type's definition. 
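Note on `union_cast_method_impl!`: the macro itself is not part of this hunk, so the following is only a sketch of the pattern it presumably generates — an accessor that returns the variant's payload and panics when called on the wrong variant. The enum name, payload types and `main` below are illustrative, not taken from the codebase.

```rust
// Illustrative only: the real macro and payload types live elsewhere in
// pass_typing.rs; this just shows the variant-cast pattern being generated.
#[derive(Debug)]
enum Rule {
    CastExpr(u32),
    CallExpr(String),
}

macro_rules! union_cast_method_impl {
    ($fn_name:ident, $payload:ty, $variant:path) => {
        fn $fn_name(&self) -> &$payload {
            match self {
                $variant(inner) => inner,
                other => panic!("called {} on {:?}", stringify!($fn_name), other),
            }
        }
    };
}

impl Rule {
    union_cast_method_impl!(as_cast_expr, u32, Rule::CastExpr);
    union_cast_method_impl!(as_call_expr, String, Rule::CallExpr);
}

fn main() {
    let rule = Rule::CastExpr(3);
    assert_eq!(*rule.as_cast_expr(), 3);
}
```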
@@ -1087,7 +1088,7 @@ impl PolyData { } } -struct VarData { +struct VarDataDeprecated { /// Type of the variable var_type: InferenceType, /// VariableExpressions that use the variable @@ -1097,7 +1098,7 @@ struct VarData { linked_var: Option, } -impl VarData { +impl VarDataDeprecated { fn new_channel(var_type: InferenceType, other_port: VariableId) -> Self { Self{ var_type, used_at: Vec::new(), linked_var: Some(other_port) } } @@ -1106,6 +1107,13 @@ impl VarData { } } +struct VarData { + var_id: VariableId, + var_type: InferenceType, + used_at: Vec, // of variable expressions + linked_var: Option, +} + impl PassTyping { pub(crate) fn new() -> Self { PassTyping { @@ -1236,7 +1244,7 @@ impl PassTyping { let param = &ctx.heap[param_id]; let var_type = self.determine_inference_type_from_parser_type_elements(¶m.parser_type.elements, true); debug_assert!(var_type.is_done, "expected component arguments to be concrete types"); - self.var_types.insert(param_id, VarData::new_local(var_type)); + self.var_types.insert(param_id, VarDataDeprecated::new_local(var_type)); } section.forget(); @@ -1277,7 +1285,7 @@ impl PassTyping { let param = &ctx.heap[param_id]; let var_type = self.determine_inference_type_from_parser_type_elements(¶m.parser_type.elements, true); debug_assert!(var_type.is_done, "expected function arguments to be concrete types"); - self.var_types.insert(param_id, VarData::new_local(var_type)); + self.var_types.insert(param_id, VarDataDeprecated::new_local(var_type)); } section.forget(); @@ -1314,10 +1322,14 @@ impl PassTyping { let memory_stmt = &ctx.heap[id]; let initial_expr_id = memory_stmt.initial_expr; - // Setup memory statement inference let local = &ctx.heap[memory_stmt.variable]; let var_type = self.determine_inference_type_from_parser_type_elements(&local.parser_type.elements, true); - self.var_types.insert(memory_stmt.variable, VarData::new_local(var_type)); + self.var_data.push(VarData{ + var_id: memory_stmt.variable, + var_type, + used_at: Vec::new(), + linked_var: None, + }); // Process the initial value self.visit_assignment_expr(ctx, initial_expr_id)?; @@ -1328,13 +1340,26 @@ impl PassTyping { fn visit_local_channel_stmt(&mut self, ctx: &mut Ctx, id: ChannelStatementId) -> VisitorResult { let channel_stmt = &ctx.heap[id]; + let from_var_index = self.var_data.len() as VarDataIndex; + let to_var_index = from_var_index + 1; + let from_local = &ctx.heap[channel_stmt.from]; let from_var_type = self.determine_inference_type_from_parser_type_elements(&from_local.parser_type.elements, true); - self.var_types.insert(from_local.this, VarData::new_channel(from_var_type, channel_stmt.to)); + self.var_data.push(VarData{ + var_id: channel_stmt.from, + var_type: from_var_type, + used_at: Vec::new(), + linked_var: Some(to_var_index), + }); let to_local = &ctx.heap[channel_stmt.to]; let to_var_type = self.determine_inference_type_from_parser_type_elements(&to_local.parser_type.elements, true); - self.var_types.insert(to_local.this, VarData::new_channel(to_var_type, channel_stmt.from)); + self.var_data.push(VarData{ + var_id: channel_stmt.to, + var_type: to_var_type, + used_at: Vec::new(), + linked_var: Some(from_var_index), + }); Ok(()) } @@ -1489,7 +1514,8 @@ impl PassTyping { }); self.parent_index = old_parent; - self.progress_assignment_expr(ctx, id) + self.progress_inference_rule_tri_equal_args(ctx, self_index)?; + return Ok(self_index); } fn visit_binding_expr(&mut self, ctx: &mut Ctx, id: BindingExpressionId) -> VisitExprResult { @@ -1513,7 +1539,8 @@ impl PassTyping { }); 
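The channel-statement handling above reserves two consecutive `var_data` slots and cross-links them by index before either entry is pushed. A stripped-down sketch of that pattern, with `VariableId`, `InferenceType` and the rest of `VarData` left out:

```rust
// Sketch (simplified types): the two endpoints are pushed back-to-back and
// point at each other by index, which is what later lets progress on one
// port's payload type be forwarded to the other.
type VarDataIndex = usize;

struct VarData {
    linked_var: Option<VarDataIndex>,
}

fn push_channel_pair(var_data: &mut Vec<VarData>) -> (VarDataIndex, VarDataIndex) {
    let from_index = var_data.len();
    let to_index = from_index + 1;
    var_data.push(VarData { linked_var: Some(to_index) });
    var_data.push(VarData { linked_var: Some(from_index) });
    (from_index, to_index)
}

fn main() {
    let mut var_data = Vec::new();
    let (from_index, to_index) = push_channel_pair(&mut var_data);
    assert_eq!(var_data[from_index].linked_var, Some(to_index));
    assert_eq!(var_data[to_index].linked_var, Some(from_index));
}
```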
self.parent_index = old_parent; - self.progress_binding_expr(ctx, id) + self.progress_inference_rule_tri_equal_args(ctx, self_index)?; + return Ok(self_index); } fn visit_conditional_expr(&mut self, ctx: &mut Ctx, id: ConditionalExpressionId) -> VisitExprResult { @@ -1542,7 +1569,8 @@ impl PassTyping { }); self.parent_index = old_parent; - self.progress_conditional_expr(ctx, id) + self.progress_inference_rule_tri_equal_all(ctx, self_index)?; + return Ok(self_index); } fn visit_binary_expr(&mut self, ctx: &mut Ctx, id: BinaryExpressionId) -> VisitExprResult { @@ -1556,7 +1584,7 @@ impl PassTyping { let lhs_expr_id = binary_expr.left; let rhs_expr_id = binary_expr.right; - let parent_index = self.parent_index.replace(self_index); + let old_parent = self.parent_index.replace(self_index); let left_index = self.visit_expr(ctx, lhs_expr_id)?; let right_index = self.visit_expr(ctx, rhs_expr_id)?; @@ -1604,7 +1632,8 @@ impl PassTyping { node.inference_rule = inference_rule; self.parent_index = old_parent; - self.progress_binary_expr(ctx, id) + self.progress_inference_rule(ctx, self_index)?; + return Ok(self_index); } fn visit_unary_expr(&mut self, ctx: &mut Ctx, id: UnaryExpressionId) -> VisitExprResult { @@ -1635,7 +1664,8 @@ impl PassTyping { }); self.parent_index = old_parent; - self.progress_unary_expr(ctx, id) + self.progress_inference_rule_bi_equal(ctx, self_index)?; + return Ok(self_index); } fn visit_indexing_expr(&mut self, ctx: &mut Ctx, id: IndexingExpressionId) -> VisitExprResult { @@ -1656,7 +1686,8 @@ impl PassTyping { }); self.parent_index = old_parent; - self.progress_indexing_expr(ctx, id) + self.progress_inference_rule_indexing_expr(ctx, self_index)?; + return Ok(self_index); } fn visit_slicing_expr(&mut self, ctx: &mut Ctx, id: SlicingExpressionId) -> VisitExprResult { @@ -1679,7 +1710,8 @@ impl PassTyping { }); self.parent_index = old_parent; - self.progress_slicing_expr(ctx, id) + self.progress_inference_rule_slicing_expr(ctx, self_index)?; + return Ok(self_index); } fn visit_select_expr(&mut self, ctx: &mut Ctx, id: SelectExpressionId) -> VisitExprResult { @@ -1693,12 +1725,23 @@ impl PassTyping { let subject_index = self.visit_expr(ctx, subject_expr_id)?; let node = &mut self.infer_nodes[self_index]; - node.inference_rule = InferenceRule::SelectExpr(InferenceRuleSelectExpr{ - subject_index, - }); + let inference_rule = match &ctx.heap[id].kind { + SelectKind::StructField(field_identifier) => + InferenceRule::SelectStructField(InferenceRuleSelectStructField{ + subject_index, + selected_field: field_identifier.clone(), + }), + SelectKind::TupleMember(member_index) => + InferenceRule::SelectTupleMember(InferenceRuleSelectTupleMember{ + subject_index, + selected_index: *member_index, + }), + }; + node.inference_rule = inference_rule; self.parent_index = old_parent; - self.progress_select_expr(ctx, id) + self.progress_inference_rule(ctx, self_index)?; + return Ok(self_index); } fn visit_literal_expr(&mut self, ctx: &mut Ctx, id: LiteralExpressionId) -> VisitExprResult { @@ -1760,9 +1803,7 @@ impl PassTyping { let extra_index = self.insert_initial_enum_polymorph_data(ctx, id); let node = &mut self.infer_nodes[self_index]; node.poly_data_index = extra_index; - node.inference_rule = InferenceRule::LiteralEnum(InferenceRuleLiteralEnum{ - - }); + node.inference_rule = InferenceRule::LiteralEnum; }, Literal::Union(literal) => { // May carry subexpressions and polymorphic arguments @@ -1803,7 +1844,8 @@ impl PassTyping { } self.parent_index = old_parent; - 
self.progress_literal_expr(ctx, id) + self.progress_inference_rule(ctx, self_index)?; + return Ok(self_index); } fn visit_cast_expr(&mut self, ctx: &mut Ctx, id: CastExpressionId) -> VisitExprResult { @@ -1822,7 +1864,18 @@ impl PassTyping { }); self.parent_index = old_parent; - self.progress_cast_expr(ctx, id) + + // The cast expression is a bit special at this point: the progression + // function simply makes sure input/output types are compatible. But if + // the programmer explicitly specified the output type, then we can + // already perform that inference rule here. + { + let specified_type = self.determine_inference_type_from_parser_type_elements(&cast_expr.to_type.elements, true); + let _progress = self.apply_template_constraint(ctx, self_index, &specified_type.parts)?; + } + + self.progress_inference_rule_cast_expr(ctx, self_index)?; + return Ok(self_index); } fn visit_call_expr(&mut self, ctx: &mut Ctx, id: CallExpressionId) -> VisitExprResult { @@ -1856,35 +1909,58 @@ impl PassTyping { }); self.parent_index = old_parent; - self.progress_call_expr(ctx, id) + self.progress_inference_rule_call_expr(ctx, self_index)?; + return Ok(self_index); } fn visit_variable_expr(&mut self, ctx: &mut Ctx, id: VariableExpressionId) -> VisitExprResult { let upcast_id = id.upcast(); - self.insert_initial_inference_node(ctx, upcast_id)?; + let self_index = self.insert_initial_inference_node(ctx, upcast_id)?; let var_expr = &ctx.heap[id]; debug_assert!(var_expr.declaration.is_some()); + let old_parent = self.parent_index.replace(self_index); - // Not pretty: if a binding expression, then this is the first time we - // encounter the variable, so we still need to insert the variable data. let declaration = &ctx.heap[var_expr.declaration.unwrap()]; - if !self.var_types.contains_key(&declaration.this) { - debug_assert!(declaration.kind == VariableKind::Binding); + let mut var_data_index = None; + for (index, var_data) in self.var_data.iter().enumerate() { + if var_data.var_id == declaration.this { + var_data_index = Some(index); + break; + } + } + + let var_data_index = if let Some(var_data_index) = var_data_index { + let var_data = &mut self.var_data[var_data_index]; + var_data.used_at.push(self_index); + + var_data_index + } else { + // If we're in a binding expression then it might the first time we + // encounter the variable, so add a `VarData` entry. 
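Several visitor methods above share the same parent bookkeeping: `Option::replace` swaps in the current node's index and returns the previous parent, which is restored once the children have been visited (the binary-expression hunk renames the saved value so the restore refers to the binding that actually exists). A minimal sketch of the pattern, with recursion and real node indices simplified away:

```rust
// Save the old parent, visit children with `self_index` as the parent,
// then restore the old parent before handing control back to the caller.
struct Visitor {
    parent_index: Option<usize>,
}

impl Visitor {
    fn visit_expr(&mut self, self_index: usize, children: &[usize]) -> usize {
        let old_parent = self.parent_index.replace(self_index);
        for &child_index in children {
            // every child visited here observes `parent_index == Some(self_index)`
            assert_eq!(self.parent_index, Some(self_index));
            let _ = child_index;
        }
        self.parent_index = old_parent; // restore on the way back up
        self_index
    }
}

fn main() {
    let mut visitor = Visitor { parent_index: None };
    visitor.visit_expr(0, &[1, 2]);
    assert_eq!(visitor.parent_index, None);
}
```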
+ debug_assert_eq!(declaration.kind, VariableKind::Binding); let var_type = self.determine_inference_type_from_parser_type_elements( &declaration.parser_type.elements, true ); - self.var_types.insert(declaration.this, VarData{ + let var_data_index = self.var_data.len(); + self.var_data.push(VarData{ + var_id: declaration.this, var_type, - used_at: vec![upcast_id], - linked_var: None + used_at: vec![self_index], + linked_var: None, }); - } else { - let var_data = self.var_types.get_mut(&declaration.this).unwrap(); - var_data.used_at.push(upcast_id); - } - self.progress_variable_expr(ctx, id) + var_data_index + }; + + let node = &mut self.infer_nodes[self_index]; + node.inference_rule = InferenceRule::VariableExpr(InferenceRuleVariableExpr{ + var_data_index, + }); + + self.parent_index = old_parent; + self.progress_inference_rule_variable_expr(ctx, self_index)?; + return Ok(self_index); } } @@ -2100,6 +2176,60 @@ impl PassTyping { Ok(()) } + fn progress_inference_rule(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + use InferenceRule as IR; + + let node = &self.infer_nodes[node_index]; + match &node.inference_rule { + IR::Noop => + return Ok(()), + IR::MonoTemplate(_) => + self.progress_inference_rule_mono_template(ctx, node_index), + IR::BiEqual(_) => + self.progress_inference_rule_bi_equal(ctx, node_index), + IR::TriEqualArgs(_) => + self.progress_inference_rule_tri_equal_args(ctx, node_index), + IR::TriEqualAll(_) => + self.progress_inference_rule_tri_equal_all(ctx, node_index), + IR::Concatenate(_) => + self.progress_inference_rule_concatenate(ctx, node_index), + IR::IndexingExpr(_) => + self.progress_inference_rule_indexing_expr(ctx, node_index), + IR::SlicingExpr(_) => + self.progress_inference_rule_slicing_expr(ctx, node_index), + IR::SelectStructField(_) => + self.progress_inference_rule_select_struct_field(ctx, node_index), + IR::SelectTupleMember(_) => + self.progress_inference_rule_select_tuple_member(ctx, node_index), + IR::LiteralStruct(_) => + self.progress_inference_rule_literal_struct(ctx, node_index), + IR::LiteralEnum => + self.progress_inference_rule_literal_enum(ctx, node_index), + IR::LiteralUnion(_) => + self.progress_inference_rule_literal_union(ctx, node_index), + IR::LiteralArray(_) => + self.progress_inference_rule_literal_array(ctx, node_index), + IR::LiteralTuple(_) => + self.progress_inference_rule_literal_tuple(ctx, node_index), + IR::CastExpr(_) => + self.progress_inference_rule_cast_expr(ctx, node_index), + IR::CallExpr(_) => + self.progress_inference_rule_call_expr(ctx, node_index), + IR::VariableExpr(_) => + self.progress_inference_rule_variable_expr(ctx, node_index), + } + } + + fn progress_inference_rule_mono_template(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_mono_template(); + + let progress = self.progress_template(ctx, node_index, rule.application, rule.template)?; + if progress { self.queue_node_parent(node_index); } + + return Ok(()); + } + fn progress_inference_rule_bi_equal(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { let node = &self.infer_nodes[node_index]; let rule = node.inference_rule.as_bi_equal(); @@ -2148,7 +2278,7 @@ impl PassTyping { return Ok(()); } - fn progress_inference_rule_concatenate(&mut self, ctx: &mut Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + fn progress_inference_rule_concatenate(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> 
Result<(), ParseError> { let node = &self.infer_nodes[node_index]; let rule = node.inference_rule.as_concatenate(); let arg1_index = rule.argument1_index; @@ -2270,7 +2400,7 @@ impl PassTyping { return Ok(()); } - fn progress_inference_rule_select_struct_field(&mut self, ctx: &mut Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + fn progress_inference_rule_select_struct_field(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { let node = &self.infer_nodes[node_index]; let rule = node.inference_rule.as_select_struct_field(); @@ -2383,31 +2513,16 @@ impl PassTyping { if progress_field_1 || progress_field_2 { self.queue_node_parent(node_index); } poly_progress_section.forget(); - + self.finish_polydata_constraint(node_index); return Ok(()) } - fn progress_inference_rule_select_tuple_member(&mut self, ctx: &mut Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + fn progress_inference_rule_select_tuple_member(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { let node = &self.infer_nodes[node_index]; let rule = node.inference_rule.as_select_tuple_member(); let subject_index = rule.subject_index; let tuple_member_index = rule.selected_index; - fn get_tuple_size_from_inference_type(inference_type: &InferenceType) -> Result, ()> { - for part in &inference_type.parts { - if part.is_marker() { continue; } - if !part.is_concrete() { break; } - - if let InferenceTypePart::Tuple(size) = part { - return Ok(Some(*size)); - } else { - return Err(()); // not a tuple! - } - } - - return Ok(None); - } - if node.field_or_monomorph_index < 0 { let subject_type = &self.infer_nodes[subject_index].expr_type; let tuple_size = get_tuple_size_from_inference_type(subject_type); @@ -2512,7 +2627,7 @@ impl PassTyping { if progress_literal_1 || progress_literal_2 { self.queue_node_parent(node_index); } poly_progress_section.forget(); - + self.finish_polydata_constraint(node_index); return Ok(()) } @@ -2536,6 +2651,7 @@ impl PassTyping { if progress_literal_1 || progress_literal_2 { self.queue_node_parent(node_index); } poly_progress_section.forget(); + self.finish_polydata_constraint(node_index); return Ok(()); } @@ -2581,6 +2697,7 @@ impl PassTyping { if progress_literal_1 || progress_literal_2 { self.queue_node_parent(node_index); } poly_progress_section.forget(); + self.finish_polydata_constraint(node_index); return Ok(()); } @@ -2627,7 +2744,250 @@ impl PassTyping { fn progress_inference_rule_literal_tuple(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { let node = &self.infer_nodes[node_index]; let rule = node.inference_rule.as_literal_tuple(); - + + let element_indices = self.index_buffer.start_section_initialized(&rule.element_indices); + + // Check if we need to apply the initial tuple template type. Note that + // this is a hacky check. + let num_tuple_elements = rule.element_indices.len(); + let mut template_type = Vec::with_capacity(num_tuple_elements + 1); // TODO: @performance + template_type.push(InferenceTypePart::Tuple(num_tuple_elements as u32)); + for _ in 0..num_tuple_elements { + template_type.push(InferenceTypePart::Unknown); + } + + let mut progress_literal = self.apply_template_constraint(ctx, node_index, &template_type)?; + + // Because of the (early returning error) check above, we're certain + // that the tuple has the correct number of elements. Now match each + // element expression type to the tuple subtype. 
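The template built above is simply a `Tuple(n)` header followed by `n` unknown element slots; the element loop that follows then refines each slot against the corresponding element expression. A small sketch of that template shape with a simplified part enum (names are illustrative):

```rust
// Builds the tuple template: one Tuple(n) header plus n "not yet known" slots.
#[derive(Debug, PartialEq)]
enum Part { Tuple(u32), Unknown }

fn tuple_template(num_elements: usize) -> Vec<Part> {
    let mut parts = Vec::with_capacity(num_elements + 1);
    parts.push(Part::Tuple(num_elements as u32));
    parts.extend((0..num_elements).map(|_| Part::Unknown));
    parts
}

fn main() {
    assert_eq!(
        tuple_template(2),
        vec![Part::Tuple(2), Part::Unknown, Part::Unknown]
    );
}
```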
+ let mut element_subtree_start_index = 1; // first element is InferenceTypePart::Tuple + for element_node_index in element_indices.iter_copied() { + let (progress_literal_element, progress_element) = self.apply_equal2_constraint( + ctx, node_index, node_index, element_subtree_start_index, element_node_index, 0 + )?; + + progress_literal = progress_literal || progress_literal_element; + if progress_element { + self.queue_node(element_node_index); + } + + // Prepare for next element + let subtree_end_index = InferenceType::find_subtree_end_idx(&node.expr_type.parts, element_subtree_start_index); + element_subtree_start_index = subtree_end_index; + } + debug_assert_eq!(element_subtree_end_index, node.expr_type.parts.len()); + + if progress_literal { self.queue_node_parent(node_index); } + + element_indices.forget(); + return Ok(()); + } + + fn progress_inference_rule_cast_expr(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_cast_expr(); + let subject_index = rule.subject_index; + let subject = &self.infer_nodes[subject_index]; + + // Make sure that both types are completely done. Note: a cast + // expression cannot really infer anything between the subject and the + // output type, we can only make sure that, at the end, the cast is + // correct. + if !node.expr_type.is_done || !subject.expr_type.is_done { + return Ok(()); + } + + // Both types are known, currently the only valid casts are bool, + // integer and character casts. + fn is_bool_int_or_char(parts: &[InferenceTypePart]) -> bool { + let mut index = 0; + while index < parts.len() { + let part = &parts[index]; + if !part.is_marker() { break; } + index += 1; + } + + debug_assert!(index != parts.len()); + let part = &parts[index]; + if ( + *part == InferenceTypePart::Bool || + *part == InferenceTypePart::Character || + part.is_concrete_integer() + ) { + debug_assert!(index + 1 == parts.len()); // type is done, first part does not have children -> must be at end + return true; + } else { + return false; + } + } + + let is_valid = if is_bool_int_or_char(&node.expr_type.parts) && is_bool_int_or_char(&subject.expr_type.parts) { + true + } else if InferenceType::check_subtrees(&node.expr_type.parts, 0, &subject.expr_type.parts, 0) { + // again: check_subtrees is sufficient since both types are done + true + } else { + false + }; + + if !is_valid { + let cast_expr = &ctx.heap[node.expr_id]; + let subject_expr = &ctx.heap[subject.expr_id]; + return Err(ParseError::new_error_str_at_span( + &ctx.module().source, cast_expr.full_span(), "invalid casting operation" + ).with_info_at_span( + &ctx.module.source, subject_expr.full_span(), format!( + "cannot cast the argument type '{}' to the type '{}'", + subject.expr_type.display_name(&ctx.heap), + node.expr_type.display_name(&ctx.heap) + ) + )); + } + + return Ok(()) + } + + fn progress_inference_rule_call_expr(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_call_expr(); + + let mut poly_progress_section = self.poly_progress_buffer.start_section(); + let argument_node_indices = self.index_buffer.start_section_initialized(&rule.argument_indices); + + // Perform inference on arguments to function, while trying to figure + // out the polymorphic variables + for (argument_index, argument_node_index) in argument_node_indices.iter_copied().enumerate() { + let argument_expr_id = 
self.infer_nodes[argument_node_index].expr_id; + let (_, progress_argument) = self.apply_polydata_equal2_constraint( + ctx, node_index, argument_expr_id, "argument's", + PolyDataTypeIndex::Associated(argument_index), 0, + argument_node_index, 0, &mut poly_progress_section + )?; + + if progress_argument { self.queue_node(argument_node_index); } + } + + // Same for the return type. + let call_expr_id = node.expr_id; + let (_, progress_call_1) = self.apply_polydata_equal2_constraint( + ctx, node_index, call_expr_id, "return", + PolyDataTypeIndex::Returned, 0, + node_index, 0, &mut poly_progress_section + )?; + + // We will now apply any progression in the polymorphic variable type + // back to the arguments. + for (argument_index, argument_node_index) in argument_node_indices.iter_copied().enumerate() { + let progress_argument = self.apply_polydata_polyvar_constraint( + ctx, node_index, PolyDataTypeIndex::Associated(argument_index), + argument_node_index, &poly_progress_section + ); + + if progress_argument { self.queue_node(argument_node_index); } + } + + // And back to the return type. + let progress_call_2 = self.apply_polydata_polyvar_constraint( + ctx, node_index, PolyDataTypeIndex::Returned, + node_index, &poly_progress_section + ); + + if progress_call_1 || progress_call_2 { self.queue_node_parent(node_index); } + + poly_progress_section.forget(); + argument_node_indices.forget(); + + self.finish_polydata_constraint(node_index); + return Ok(()) + } + + fn progress_inference_rule_variable_expr(&mut self, ctx: &Ctx, node_index: InferNodeIndex) -> Result<(), ParseError> { + let node = &mut self.infer_nodes[node_index]; + let rule = node.inference_rule.as_variable_expr(); + let var_data_index = rule.var_data_index; + let var_data = &mut self.var_data[var_data_index]; + + // Apply inference to the shared variable type and the expression type + let shared_type: *mut _ = &mut var_data.var_type; + let expr_type: *mut _ = &mut node.expr_type; + + let inference_result = unsafe { + // safety: vectors exist in different storage vectors, so cannot alias + InferenceType::infer_subtrees_for_both_types(shared_type, 0, expr_type, 0) + }; + + if inference_result == DualInferenceResult::Incompatible { + return Err(self.construct_variable_type_error(ctx, node_index)); + } + + let progress_var_data = inference_result.modified_lhs(); + let progress_expr = inference_result.modified_rhs(); + + if progress_var_data { + // We progressed the type of the shared variable, so propagate this + // to all associated variable expressions (and relatived variables). + for other_node_index in var_data.used_at.iter().copied() { + if other_node_index != node_index { + self.queue_node(other_node_index); + } + } + + if let Some(linked_var_data_index) = var_data.linked_var { + // Only perform one-way inference, progressing the linked + // variable. + // note: because this "linking" is used only for channels, we + // will start inference one level below the top-level in the + // type tree (i.e. ensure `T` in `in` and `out` is equal). 
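The raw pointers above exist so `self.infer_nodes[..]` and `self.var_data[..]` can be borrowed mutably at the same time. Because the two inference types live in different fields of `PassTyping`, a safe alternative is to split the borrow by destructuring `self` first; the sketch below uses simplified types and leaves out the actual dual-inference call, so treat it as an assumption about shape rather than a drop-in replacement:

```rust
// Sketch only: field names mirror the diff, everything else is simplified.
// Destructuring `self` yields one mutable borrow per field, so the two
// `&mut InferenceType` values provably cannot alias.
struct InferenceType { parts: Vec<u32> }
struct InferNode { expr_type: InferenceType }
struct VarData { var_type: InferenceType }

struct PassTyping {
    infer_nodes: Vec<InferNode>,
    var_data: Vec<VarData>,
}

impl PassTyping {
    fn infer_both(&mut self, node_index: usize, var_data_index: usize) {
        let PassTyping { infer_nodes, var_data } = self;
        let expr_type: &mut InferenceType = &mut infer_nodes[node_index].expr_type;
        let shared_type: &mut InferenceType = &mut var_data[var_data_index].var_type;
        // the dual-inference call would take `shared_type` and `expr_type` here
        let _ = (expr_type, shared_type);
    }
}

fn main() {
    let mut pass = PassTyping {
        infer_nodes: vec![InferNode { expr_type: InferenceType { parts: vec![] } }],
        var_data: vec![VarData { var_type: InferenceType { parts: vec![] } }],
    };
    pass.infer_both(0, 0);
}
```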
+                debug_assert!(
+                    var_data.var_type.parts[0] == InferenceTypePart::Input ||
+                    var_data.var_type.parts[0] == InferenceTypePart::Output
+                );
+                let this_var_type: *const _ = &var_data.var_type;
+                let linked_var_data = &mut self.var_data[linked_var_data_index];
+                debug_assert!(
+                    linked_var_data.var_type.parts[0] == InferenceTypePart::Input ||
+                    linked_var_data.var_type.parts[0] == InferenceTypePart::Output
+                );
+
+                // safety: by construction var_data_index and linked_var_data_index cannot be the
+                // same, hence we're not aliasing here.
+                let inference_result = InferenceType::infer_subtree_for_single_type(
+                    &mut linked_var_data.var_type, 1,
+                    unsafe{ &(*this_var_type).parts }, 1, false
+                );
+                match inference_result {
+                    SingleInferenceResult::Modified => {
+                        for used_at in linked_var_data.used_at.iter().copied() {
+                            self.queue_node(used_at);
+                        }
+                    },
+                    SingleInferenceResult::Unmodified => {},
+                    SingleInferenceResult::Incompatible => {
+                        let var_data_this = &self.var_data[var_data_index];
+                        let var_decl_this = &ctx.heap[var_data_this.var_id];
+                        let var_data_linked = &self.var_data[linked_var_data_index];
+                        let var_decl_linked = &ctx.heap[var_data_linked.var_id];
+
+                        return Err(ParseError::new_error_at_span(
+                            &ctx.module().source, var_decl_this.identifier.span, format!(
+                                "conflicting types for this channel, this port has type '{}'",
+                                var_data_this.var_type.display_name(&ctx.heap)
+                            )
+                        ).with_info_at_span(
+                            &ctx.module().source, var_decl_linked.identifier.span, format!(
+                                "while this port has type '{}'",
+                                var_data_linked.var_type.display_name(&ctx.heap)
+                            )
+                        ));
+                    }
+                }
+            }
+        }
+
+        if progress_expr { self.queue_node_parent(node_index); }
+
+        return Ok(());
     }
 
     fn progress_template(&mut self, ctx: &Ctx, node_index: InferNodeIndex, application: InferenceRuleTemplateApplication, template: &[InferenceTypePart]) -> Result<bool, ParseError> {
@@ -4011,7 +4371,7 @@ impl PassTyping {
     /// and another inferred type. If any progress is made in the `PolyData`
     /// struct then the affected polymorphic variables are updated as well.
     ///
-    /// Because a lot of types/expressions are involved in polymorphic typFe
+    /// Because a lot of types/expressions are involved in polymorphic type
     /// inference, some explanation: "outer_node" refers to the main expression
    /// that is the root cause of type inference (e.g. a struct literal
    /// expression, or a tuple member select expression).
Associated with that @@ -4124,6 +4484,12 @@ impl PassTyping { let poly_data_index = self.infer_nodes[outer_node_index].poly_data_index; let poly_data = &mut self.poly_data[poly_data_index]; + // Early exit, most common case (literals or functions calls which are + // actually not polymorphic) + if !poly_data.first_rule_application && poly_progress_section.len() == 0 { + return false; + } + // safety: we're borrowing from two distinct fields, so should be fine let poly_data_type = poly_data.get_type_mut(poly_data_type_index); let mut last_start_index = 0; @@ -4131,7 +4497,8 @@ impl PassTyping { while let Some((poly_var_index, poly_var_start_index)) = poly_data_type.find_marker(last_start_index) { let poly_var_end_index = InferenceType::find_subtree_end_idx(&poly_data_type.parts, poly_var_start_index); - if poly_progress_section.contains(&poly_var_index) { + + if poly_data.first_rule_application || poly_progress_section.contains(&poly_var_index) { // We have updated this polymorphic variable, so try updating it // in the PolyData type let modified_in_poly_data = match InferenceType::infer_subtree_for_single_type( @@ -4169,6 +4536,14 @@ impl PassTyping { } } + /// Should be called after completing one full round of applying polydata + /// constraints. + fn finish_polydata_constraint(&mut self, outer_node_index: InferNodeIndex) { + let poly_data_index = self.infer_nodes[outer_node_index].poly_data_index; + let poly_data = &mut self.poly_data[poly_data_index]; + poly_data.first_rule_application = false; + } + /// Applies an equal2 constraint between a signature type (e.g. a function /// argument or struct field) and an expression whose type should match that /// expression. If we make progress on the signature, then we try to see if @@ -4960,6 +5335,29 @@ impl PassTyping { ) } + fn construct_variable_type_error( + &self, ctx: &Ctx, node_index: InferNodeIndex, + ) -> ParseError { + let node = &self.infer_nodes[node_index]; + let rule = node.inference_rule.as_variable_expr(); + + let var_data = &self.var_data[rule.var_data_index]; + let var_decl = &ctx.heap[var_data.var_id]; + let var_expr = &ctx.heap[node.expr_id]; + + return ParseError::new_error_at_span( + &ctx.module().source, var_decl.identifier.span, format!( + "conflicting types for this variable, previously assigned the type '{}'", + var_data.var_type.display_name(&ctx.heap) + ) + ).with_info_at_span( + &ctx.module().source, var_expr.full_span(), format!( + "but inferred to have incompatible type '{}' here", + node.expr_type.display_name(&ctx.heap) + ) + ); + } + /// Constructs a human interpretable error in the case that type inference /// on a polymorphic variable to a function call or literal construction /// failed. This may only be caused by a pair of inference types (which may @@ -5194,6 +5592,21 @@ impl PassTyping { } } +fn get_tuple_size_from_inference_type(inference_type: &InferenceType) -> Result, ()> { + for part in &inference_type.parts { + if part.is_marker() { continue; } + if !part.is_concrete() { break; } + + if let InferenceTypePart::Tuple(size) = part { + return Ok(Some(*size)); + } else { + return Err(()); // not a tuple! + } + } + + return Ok(None); +} + #[cfg(test)] mod tests { use super::*;
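The relocated `get_tuple_size_from_inference_type` helper presumably returns `Result<Option<u32>, ()>`: `Ok(Some(n))` for a known tuple, `Ok(None)` while the type is still unknown, and `Err(())` when it is concrete but not a tuple. A self-contained sketch of that logic with a simplified part enum:

```rust
// Simplified stand-in for InferenceTypePart; only what the helper needs.
enum InferenceTypePart { Marker, Unknown, Tuple(u32), UInt32 }

impl InferenceTypePart {
    fn is_marker(&self) -> bool { matches!(self, InferenceTypePart::Marker) }
    fn is_concrete(&self) -> bool {
        matches!(self, InferenceTypePart::Tuple(_) | InferenceTypePart::UInt32)
    }
}

/// Ok(Some(n)) for a tuple of known size, Ok(None) if the type is not yet
/// known, Err(()) if the type is concrete but not a tuple.
fn get_tuple_size(parts: &[InferenceTypePart]) -> Result<Option<u32>, ()> {
    for part in parts {
        if part.is_marker() { continue; }
        if !part.is_concrete() { break; }
        if let InferenceTypePart::Tuple(size) = part {
            return Ok(Some(*size));
        } else {
            return Err(()); // concrete, but not a tuple
        }
    }
    Ok(None)
}

fn main() {
    assert_eq!(get_tuple_size(&[InferenceTypePart::Tuple(3)]), Ok(Some(3)));
    assert_eq!(get_tuple_size(&[InferenceTypePart::Unknown]), Ok(None));
    assert_eq!(get_tuple_size(&[InferenceTypePart::UInt32]), Err(()));
}
```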