diff --git a/src/protocol/parser/pass_rewriting.rs b/src/protocol/parser/pass_rewriting.rs index 5884cdca726ff74bb13f9733d73ab3bd8e5f8f6d..d5080c1c04c63d750f292e144b621d08a2f66b4e 100644 --- a/src/protocol/parser/pass_rewriting.rs +++ b/src/protocol/parser/pass_rewriting.rs @@ -5,6 +5,7 @@ use super::visitor::*; pub(crate) struct PassRewriting { current_scope: ScopeId, + current_procedure_id: ProcedureDefinitionId, definition_buffer: ScopedBuffer, statement_buffer: ScopedBuffer, call_expr_buffer: ScopedBuffer, @@ -16,6 +17,7 @@ impl PassRewriting { pub(crate) fn new() -> Self { Self{ current_scope: ScopeId::new_invalid(), + current_procedure_id: ProcedureDefinitionId::new_invalid(), definition_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE), statement_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), call_expr_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL), @@ -49,6 +51,7 @@ impl Visitor for PassRewriting { let definition = &ctx.heap[id]; let body_id = definition.body; self.current_scope = definition.scope; + self.current_procedure_id = id; return self.visit_block_stmt(ctx, body_id); } @@ -108,7 +111,11 @@ impl Visitor for PassRewriting { // Utility for the last stage of rewriting process. Note that caller // still needs to point the end of the if-statement to the end of the // replacement statement of the select statement. - fn transform_select_case_code(ctx: &mut Ctx, select_id: SelectStatementId, case_index: usize, select_var_id: VariableId) -> (IfStatementId, EndIfStatementId) { + fn transform_select_case_code( + ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, + select_id: SelectStatementId, case_index: usize, + select_var_id: VariableId, select_var_type_id: TypeIdReference + ) -> (IfStatementId, EndIfStatementId) { // Retrieve statement IDs associated with case let case = &ctx.heap[select_id].cases[case_index]; let case_guard_id = case.guard; @@ -116,7 +123,7 @@ impl Visitor for PassRewriting { let case_scope_id = case.scope; // Create the if-statement for the result of the select statement - let compare_expr_id = create_ast_equality_comparison_expr(ctx, select_var_id, case_index as u64); + let compare_expr_id = create_ast_equality_comparison_expr(ctx, containing_procedure_id, select_var_id, select_var_type_id, case_index as u64); let true_case = IfStatementCase{ body: case_guard_id, // which is linked up to the body scope: case_scope_id, @@ -177,19 +184,21 @@ impl Visitor for PassRewriting { for port_var_idx in 0..call_id_section.len() { let get_call_expr_id = call_id_section[port_var_idx]; let port_expr_id = expr_id_section[port_var_idx]; + let port_type_index = ctx.heap[port_expr_id].type_index(); + let port_type_ref = TypeIdReference::IndirectSameAsExpr(port_type_index); // Move the port expression such that it gets assigned to a temporary variable let variable_id = create_ast_variable(ctx, outer_scope_id); - let variable_decl_stmt_id = create_ast_variable_declaration_stmt(ctx, variable_id, port_expr_id); + let variable_decl_stmt_id = create_ast_variable_declaration_stmt(ctx, self.current_procedure_id, variable_id, port_type_ref, port_expr_id); // Replace the original port expression in the call with a reference // to the replacement variable - let variable_expr_id = create_ast_variable_expr(ctx, variable_id); + let variable_expr_id = create_ast_variable_expr(ctx, self.current_procedure_id, variable_id, port_type_ref); let call_expr = &mut ctx.heap[get_call_expr_id]; call_expr.arguments[0] = variable_expr_id.upcast(); transformed_stmts.push(variable_decl_stmt_id.upcast().upcast()); - locals.push(variable_id); + locals.push((variable_id, port_type_ref)); } // Insert runtime calls that facilitate the semantics of the select @@ -197,14 +206,14 @@ impl Visitor for PassRewriting { // Create the call that indicates the start of the select block { - let num_cases_expression_id = create_ast_literal_integer_expr(ctx, total_num_cases as u64); - let num_ports_expression_id = create_ast_literal_integer_expr(ctx, total_num_ports as u64); + let num_cases_expression_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, total_num_cases as u64); + let num_ports_expression_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, total_num_ports as u64); let arguments = vec![ num_cases_expression_id.upcast(), num_ports_expression_id.upcast() ]; - let call_expression_id = create_ast_call_expr(ctx, Method::SelectStart, &mut self.expression_buffer, arguments); + let call_expression_id = create_ast_call_expr(ctx, self.current_procedure_id, Method::SelectStart, &mut self.expression_buffer, arguments); let call_statement_id = create_ast_expression_stmt(ctx, call_expression_id.upcast()); transformed_stmts.push(call_statement_id.upcast()); @@ -220,10 +229,10 @@ impl Visitor for PassRewriting { for case_port_index in 0..case_num_ports { // Arguments to runtime call - let port_variable_id = locals[total_port_index]; // so far this variable contains the temporary variables for the port expressions - let case_index_expr_id = create_ast_literal_integer_expr(ctx, case_index as u64); - let port_index_expr_id = create_ast_literal_integer_expr(ctx, case_port_index as u64); - let port_variable_expr_id = create_ast_variable_expr(ctx, port_variable_id); + let (port_variable_id, port_variable_type) = locals[total_port_index]; // so far this variable contains the temporary variables for the port expressions + let case_index_expr_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, case_index as u64); + let port_index_expr_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, case_port_index as u64); + let port_variable_expr_id = create_ast_variable_expr(ctx, self.current_procedure_id, port_variable_id, port_variable_type); let runtime_call_arguments = vec![ case_index_expr_id.upcast(), port_index_expr_id.upcast(), @@ -231,7 +240,7 @@ impl Visitor for PassRewriting { ]; // Create runtime call, then store it - let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectRegisterCasePort, &mut self.expression_buffer, runtime_call_arguments); + let runtime_call_expr_id = create_ast_call_expr(ctx, self.current_procedure_id, Method::SelectRegisterCasePort, &mut self.expression_buffer, runtime_call_arguments); let runtime_call_stmt_id = create_ast_expression_stmt(ctx, runtime_call_expr_id.upcast()); transformed_stmts.push(runtime_call_stmt_id.upcast()); @@ -244,11 +253,12 @@ impl Visitor for PassRewriting { // Create the variable that will hold the result of a completed select // block. Then create the runtime call that will produce this result let select_variable_id = create_ast_variable(ctx, outer_scope_id); - locals.push(select_variable_id); + let select_variable_type = TypeIdReference::DirectTypeId(ctx.arch.uint32_type_id); + locals.push((select_variable_id, select_variable_type)); { - let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectWait, &mut self.expression_buffer, Vec::new()); - let variable_stmt_id = create_ast_variable_declaration_stmt(ctx, select_variable_id, runtime_call_expr_id.upcast()); + let runtime_call_expr_id = create_ast_call_expr(ctx, self.current_procedure_id, Method::SelectWait, &mut self.expression_buffer, Vec::new()); + let variable_stmt_id = create_ast_variable_declaration_stmt(ctx, self.current_procedure_id, select_variable_id, select_variable_type, runtime_call_expr_id.upcast()); transformed_stmts.push(variable_stmt_id.upcast().upcast()); } @@ -258,7 +268,7 @@ impl Visitor for PassRewriting { // Now we transform each of the select block case's guard and code into // a chained if-else statement. if total_num_cases > 0 { - let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, id, 0, select_variable_id); + let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, self.current_procedure_id, id, 0, select_variable_id, select_variable_type); let first_end_if_stmt = &mut ctx.heap[end_if_stmt_id]; first_end_if_stmt.next = outer_end_block_id.upcast(); @@ -267,7 +277,7 @@ impl Visitor for PassRewriting { transformed_stmts.push(last_if_stmt_id.upcast()); for case_index in 1..total_num_cases { - let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, id, case_index, select_variable_id); + let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, self.current_procedure_id, id, case_index, select_variable_id, select_variable_type); let false_case_scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::If(last_if_stmt_id, false))); set_ast_if_statement_false_body(ctx, last_if_stmt_id, last_end_if_stmt_id, IfStatementCase{ body: if_stmt_id.upcast(), scope: false_case_scope_id }); @@ -298,6 +308,12 @@ impl Visitor for PassRewriting { // Utilities to create compiler-generated AST nodes // ----------------------------------------------------------------------------- +#[derive(Clone, Copy)] +enum TypeIdReference { + DirectTypeId(TypeId), + IndirectSameAsExpr(i32), // by type index +} + fn create_ast_variable(ctx: &mut Ctx, scope_id: ScopeId) -> VariableId { let variable_id = ctx.heap.alloc_variable(|this| Variable{ this, @@ -316,22 +332,31 @@ fn create_ast_variable(ctx: &mut Ctx, scope_id: ScopeId) -> VariableId { return variable_id; } -fn create_ast_variable_expr(ctx: &mut Ctx, variable_id: VariableId) -> VariableExpressionId { +fn create_ast_variable_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, variable_id: VariableId, variable_type_id: TypeIdReference) -> VariableExpressionId { + let variable_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, variable_type_id); return ctx.heap.alloc_variable_expression(|this| VariableExpression{ this, identifier: Identifier::new_empty(InputSpan::new()), declaration: Some(variable_id), used_as_binding_target: false, parent: ExpressionParent::None, - type_index: -1 + type_index: variable_type_index, }); } -fn create_ast_call_expr(ctx: &mut Ctx, method: Method, buffer: &mut ScopedBuffer, arguments: Vec) -> CallExpressionId { +fn create_ast_call_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, method: Method, buffer: &mut ScopedBuffer, arguments: Vec) -> CallExpressionId { + let call_type_id = match method { + Method::SelectStart => ctx.arch.void_type_id, + Method::SelectRegisterCasePort => ctx.arch.void_type_id, + Method::SelectWait => ctx.arch.uint32_type_id, // TODO: Not pretty, this. Pretty error prone + _ => unreachable!(), // if this goes of, add the appropriate method here. + }; + let expression_ids = buffer.start_section_initialized(&arguments); + let call_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(call_type_id)); let call_expression_id = ctx.heap.alloc_call_expression(|this| CallExpression{ - this, func_span: InputSpan::new(), + this, full_span: InputSpan::new(), parser_type: ParserType{ elements: Vec::new(), @@ -341,7 +366,7 @@ fn create_ast_call_expr(ctx: &mut Ctx, method: Method, buffer: &mut ScopedBuffer arguments, procedure: ProcedureDefinitionId::new_invalid(), parent: ExpressionParent::None, - type_index: -1, + type_index: call_type_index, }); for argument_index in 0..expression_ids.len() { @@ -353,7 +378,8 @@ fn create_ast_call_expr(ctx: &mut Ctx, method: Method, buffer: &mut ScopedBuffer return call_expression_id; } -fn create_ast_literal_integer_expr(ctx: &mut Ctx, unsigned_value: u64) -> LiteralExpressionId { +fn create_ast_literal_integer_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, unsigned_value: u64) -> LiteralExpressionId { + let literal_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(ctx.arch.uint64_type_id)); return ctx.heap.alloc_literal_expression(|this| LiteralExpression{ this, span: InputSpan::new(), @@ -362,13 +388,17 @@ fn create_ast_literal_integer_expr(ctx: &mut Ctx, unsigned_value: u64) -> Litera negated: false, }), parent: ExpressionParent::None, - type_index: -1 + type_index: literal_type_index, }); } -fn create_ast_equality_comparison_expr(ctx: &mut Ctx, variable_id: VariableId, value: u64) -> BinaryExpressionId { - let var_expr_id = create_ast_variable_expr(ctx, variable_id); - let int_expr_id = create_ast_literal_integer_expr(ctx, value); +fn create_ast_equality_comparison_expr( + ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, + variable_id: VariableId, variable_type: TypeIdReference, value: u64 +) -> BinaryExpressionId { + let var_expr_id = create_ast_variable_expr(ctx, containing_procedure_id, variable_id, variable_type); + let int_expr_id = create_ast_literal_integer_expr(ctx, containing_procedure_id, value); + let cmp_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(ctx.arch.bool_type_id)); let cmp_expr_id = ctx.heap.alloc_binary_expression(|this| BinaryExpression{ this, operator_span: InputSpan::new(), @@ -377,7 +407,7 @@ fn create_ast_equality_comparison_expr(ctx: &mut Ctx, variable_id: VariableId, v operation: BinaryOperator::Equality, right: int_expr_id.upcast(), parent: ExpressionParent::None, - type_index: -1, + type_index: cmp_type_index, }); let var_expr = &mut ctx.heap[var_expr_id]; @@ -402,9 +432,12 @@ fn create_ast_expression_stmt(ctx: &mut Ctx, expression_id: ExpressionId) -> Exp return statement_id; } -fn create_ast_variable_declaration_stmt(ctx: &mut Ctx, variable_id: VariableId, initial_value_expr_id: ExpressionId) -> MemoryStatementId { +fn create_ast_variable_declaration_stmt( + ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, + variable_id: VariableId, variable_type: TypeIdReference, initial_value_expr_id: ExpressionId +) -> MemoryStatementId { // Create the assignment expression, assigning the initial value to the variable - let variable_expr_id = create_ast_variable_expr(ctx, variable_id); + let variable_expr_id = create_ast_variable_expr(ctx, containing_procedure_id, variable_id, variable_type); let assignment_expr_id = ctx.heap.alloc_assignment_expression(|this| AssignmentExpression{ this, operator_span: InputSpan::new(), @@ -605,3 +638,29 @@ fn add_child_scope_to_parent(ctx: &mut Ctx, scope_buffer: &mut ScopedBuffer i32 { + let procedure = &mut ctx.heap[procedure_id]; + let type_index = procedure.monomorphs[0].expr_info.len(); + + match type_id { + TypeIdReference::DirectTypeId(type_id) => { + for monomorph in procedure.monomorphs.iter_mut() { + debug_assert_eq!(monomorph.expr_info.len(), type_index); + monomorph.expr_info.push(ExpressionInfo{ + type_id, + variant: ExpressionInfoVariant::Generic + }); + } + }, + TypeIdReference::IndirectSameAsExpr(source_type_index) => { + for monomorph in procedure.monomorphs.iter_mut() { + debug_assert_eq!(monomorph.expr_info.len(), type_index); + let copied_expr_info = monomorph.expr_info[source_type_index as usize]; + monomorph.expr_info.push(copied_expr_info) + } + } + } + + return type_index as i32; +}