use crate::collections::*; use crate::protocol::*; use super::visitor::*; pub(crate) struct PassRewriting { current_scope: BlockStatementId, statement_buffer: ScopedBuffer, call_expr_buffer: ScopedBuffer, expression_buffer: ScopedBuffer, } impl PassRewriting { pub(crate) fn new() -> Self { Self{ current_scope: BlockStatementId::new_invalid(), statement_buffer: ScopedBuffer::with_capacity(16), call_expr_buffer: ScopedBuffer::with_capacity(16), expression_buffer: ScopedBuffer::with_capacity(16), } } } impl Visitor for PassRewriting { // --- Visiting procedures fn visit_component_definition(&mut self, ctx: &mut Ctx, id: ComponentDefinitionId) -> VisitorResult { let def = &ctx.heap[id]; let body_id = def.body; return self.visit_block_stmt(ctx, body_id); } fn visit_function_definition(&mut self, ctx: &mut Ctx, id: FunctionDefinitionId) -> VisitorResult { let def = &ctx.heap[id]; let body_id = def.body; return self.visit_block_stmt(ctx, body_id); } // --- Visiting statements (that are not the select statement) fn visit_block_stmt(&mut self, ctx: &mut Ctx, id: BlockStatementId) -> VisitorResult { let block_stmt = &ctx.heap[id]; let stmt_section = self.statement_buffer.start_section_initialized(&block_stmt.statements); self.current_scope = id; for stmt_idx in 0..stmt_section.len() { self.visit_stmt(ctx, stmt_section[stmt_idx])?; } stmt_section.forget(); return Ok(()) } fn visit_labeled_stmt(&mut self, ctx: &mut Ctx, id: LabeledStatementId) -> VisitorResult { let labeled_stmt = &ctx.heap[id]; let body_id = labeled_stmt.body; return self.visit_stmt(ctx, body_id); } fn visit_if_stmt(&mut self, ctx: &mut Ctx, id: IfStatementId) -> VisitorResult { let if_stmt = &ctx.heap[id]; let true_body_id = if_stmt.true_case; let false_body_id = if_stmt.false_case; self.visit_stmt(ctx, true_body_id.body)?; if let Some(false_body_id) = false_body_id { self.visit_stmt(ctx, false_body_id.body)?; } return Ok(()) } fn visit_while_stmt(&mut self, ctx: &mut Ctx, id: WhileStatementId) -> VisitorResult { let while_stmt = &ctx.heap[id]; let body_id = while_stmt.body; return self.visit_stmt(ctx, body_id); } fn visit_synchronous_stmt(&mut self, ctx: &mut Ctx, id: SynchronousStatementId) -> VisitorResult { let sync_stmt = &ctx.heap[id]; let body_id = sync_stmt.body; return self.visit_stmt(ctx, body_id); } // --- Visiting the select statement fn visit_select_stmt(&mut self, ctx: &mut Ctx, id: SelectStatementId) -> VisitorResult { // Utility for the last stage of rewriting process. Note that caller // still needs to point the end of the if-statement to the end of the // replacement statement of the select statement. fn transform_select_case_code(ctx: &mut Ctx, select_id: SelectStatementId, case_index: usize, select_var_id: VariableId) -> (IfStatementId, EndIfStatementId) { // Retrieve statement IDs associated with case let case = &ctx.heap[select_id].cases[case_index]; let case_guard_id = case.guard; let case_body_id = case.body; let case_scope_id = case.scope; // Create the if-statement for the result of the select statement let compare_expr_id = create_ast_equality_comparison_expr(ctx, select_var_id, case_index as u64); let true_case = IfStatementCase{ body: case_guard_id, // which is linked up to the body scope: case_scope_id, }; let (if_stmt_id, end_if_stmt_id) = create_ast_if_stmt(ctx, compare_expr_id.upcast(), true_case, None); // Link up body statement to end-if set_ast_statement_next(ctx, case_body_id, end_if_stmt_id.upcast()); return (if_stmt_id, end_if_stmt_id) } // We're going to transform the select statement by a block statement // containing builtin runtime-calls. And to do so we create temporary // variables and move some other statements around. let select_stmt = &ctx.heap[id]; let total_num_cases = select_stmt.cases.len(); let mut total_num_ports = 0; let end_select_stmt_id = select_stmt.end_select; let end_select = &ctx.heap[end_select_stmt_id]; let stmt_id_after_select_stmt = end_select.next; // Put heap IDs into temporary buffers to handle borrowing rules let mut call_id_section = self.call_expr_buffer.start_section(); let mut expr_id_section = self.expression_buffer.start_section(); for case in select_stmt.cases.iter() { total_num_ports += case.involved_ports.len(); for (call_id, expr_id) in case.involved_ports.iter().copied() { call_id_section.push(call_id); expr_id_section.push(expr_id); } } // Transform all of the call expressions by takings its argument (the // port from which we `get`) and turning it into a temporary variable. let mut transformed_stmts = Vec::with_capacity(total_num_ports); // TODO: Recompute this preallocated length, put assert at the end let mut locals = Vec::with_capacity(total_num_ports); for port_var_idx in 0..call_id_section.len() { let get_call_expr_id = call_id_section[port_var_idx]; let port_expr_id = expr_id_section[port_var_idx]; // Move the port expression such that it gets assigned to a temporary variable let variable_id = create_ast_variable(ctx); let variable_decl_stmt_id = create_ast_variable_declaration_stmt(ctx, variable_id, port_expr_id); // Replace the original port expression in the call with a reference // to the replacement variable let variable_expr_id = create_ast_variable_expr(ctx, variable_id); let call_expr = &mut ctx.heap[get_call_expr_id]; call_expr.arguments[0] = variable_expr_id.upcast(); transformed_stmts.push(variable_decl_stmt_id.upcast().upcast()); locals.push(variable_id); } // Insert runtime calls that facilitate the semantics of the select // block. // Create the call that indicates the start of the select block { let num_cases_expression_id = create_ast_literal_integer_expr(ctx, total_num_cases as u64); let num_ports_expression_id = create_ast_literal_integer_expr(ctx, total_num_ports as u64); let arguments = vec![ num_cases_expression_id.upcast(), num_ports_expression_id.upcast() ]; let call_expression_id = create_ast_call_expr(ctx, Method::SelectStart, &mut self.expression_buffer, arguments); let call_statement_id = create_ast_expression_stmt(ctx, call_expression_id.upcast()); transformed_stmts.push(call_statement_id.upcast()); } // Create calls for each select case that will register the ports that // we are waiting on at the runtime. { let mut total_port_index = 0; for case_index in 0..total_num_cases { let case = &ctx.heap[id].cases[case_index]; let case_num_ports = case.involved_ports.len(); for case_port_index in 0..case_num_ports { // Arguments to runtime call let port_variable_id = locals[total_port_index]; // so far this variable contains the temporary variables for the port expressions let case_index_expr_id = create_ast_literal_integer_expr(ctx, case_index as u64); let port_index_expr_id = create_ast_literal_integer_expr(ctx, case_port_index as u64); let port_variable_expr_id = create_ast_variable_expr(ctx, port_variable_id); let runtime_call_arguments = vec![ case_index_expr_id.upcast(), port_index_expr_id.upcast(), port_variable_expr_id.upcast() ]; // Create runtime call, then store it let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectRegisterCasePort, &mut self.expression_buffer, runtime_call_arguments); let runtime_call_stmt_id = create_ast_expression_stmt(ctx, runtime_call_expr_id.upcast()); transformed_stmts.push(runtime_call_stmt_id.upcast()); total_port_index += 1; } } } // Create the variable that will hold the result of a completed select // block. Then create the runtime call that will produce this result let select_variable_id = create_ast_variable(ctx); locals.push(select_variable_id); { let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectWait, &mut self.expression_buffer, Vec::new()); let variable_stmt_id = create_ast_variable_declaration_stmt(ctx, select_variable_id, runtime_call_expr_id.upcast()); transformed_stmts.push(variable_stmt_id.upcast().upcast()); } call_id_section.forget(); expr_id_section.forget(); // Precreate the block statement that will be the replacement of the // select statement. Do not set its members yet. let replacement_stmt_id = ctx.heap.alloc_block_statement(|this| BlockStatement{ this, is_implicit: true, span: InputSpan::new(), statements: Vec::new(), end_block: EndBlockStatementId::new_invalid(), scope: ScopeId::new_invalid(), next: StatementId::new_invalid(), }); let end_block_id = ctx.heap.alloc_end_block_statement(|this| EndBlockStatement{ this, start_block: replacement_stmt_id, next: stmt_id_after_select_stmt, }); // Now we transform each of the select block case's guard and code into // a chained if-else statement. if total_num_cases > 0 { let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, id, 0, select_variable_id); let mut last_if_stmt_id = if_stmt_id; let mut last_end_if_stmt_id = end_if_stmt_id; transformed_stmts.push(last_if_stmt_id.upcast()); for case_index in 1..total_num_cases { let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, id, case_index, select_variable_id); let last_if_stmt = &mut ctx.heap[last_if_stmt_id]; // last_if_stmt.false_case = Some(if_stmt_id.upcast()); // TODO: // 1. Change scoping such that it is a separate datastructure with separate IDs // 2. Change statements that contain "implicit scopes" to explicitly point to the appropriate scopes // 3. Continue here setting the true-body and false-body. // 4. Figure out how we're going to link everything up again } } // let block = ctx.heap.alloc_block_statement(|this| BlockStatement{ // this, // is_implicit: true, // span: stmt.span, // statements: vec![], // end_block: EndBlockStatementId(), // scope_node: ScopeNode {}, // first_unique_id_in_scope: 0, // next_unique_id_in_scope: 0, // locals, // labels: vec![], // next: () // }); return Ok(()) } } impl PassRewriting { fn create_runtime_call_statement(&self, ctx: &mut Ctx, method: Method, arguments: Vec) -> (CallExpressionId, ExpressionStatementId) { let call_expr_id = ctx.heap.alloc_call_expression(|this| CallExpression{ this, func_span: InputSpan::new(), full_span: InputSpan::new(), parser_type: ParserType{ elements: Vec::new(), full_span: InputSpan::new(), }, method, arguments, definition: DefinitionId::new_invalid(), parent: ExpressionParent::None, unique_id_in_definition: -1, }); let call_stmt_id = ctx.heap.alloc_expression_statement(|this| ExpressionStatement{ this, span: InputSpan::new(), expression: call_expr_id.upcast(), next: StatementId::new_invalid(), }); let call_expr = &mut ctx.heap[call_expr_id]; call_expr.parent = ExpressionParent::ExpressionStmt(call_stmt_id); return (call_expr_id, call_stmt_id); } fn create_runtime_select_wait_variable_and_statement(&self, ctx: &mut Ctx) -> (VariableId, MemoryStatementId) { let variable_id = create_ast_variable(ctx); let variable_expr_id = create_ast_variable_expr(ctx, variable_id); let runtime_call_expr_id = ctx.heap.alloc_call_expression(|this| CallExpression{ this, func_span: InputSpan::new(), full_span: InputSpan::new(), parser_type: ParserType{ elements: Vec::new(), full_span: InputSpan::new(), }, method: Method::SelectWait, arguments: Vec::new(), definition: DefinitionId::new_invalid(), parent: ExpressionParent::None, unique_id_in_definition: -1 }); let initial_expr_id = ctx.heap.alloc_assignment_expression(|this| AssignmentExpression{ this, operator_span: InputSpan::new(), full_span: InputSpan::new(), left: variable_expr_id.upcast(), operation: AssignmentOperator::Set, right: runtime_call_expr_id.upcast(), parent: ExpressionParent::None, unique_id_in_definition: -1 }); let variable_statement_id = ctx.heap.alloc_memory_statement(|this| MemoryStatement{ this, span: InputSpan::new(), variable: variable_id, initial_expr: initial_expr_id, next: StatementId::new_invalid() }); let variable_expr = &mut ctx.heap[variable_expr_id]; variable_expr.parent = ExpressionParent::Expression(initial_expr_id.upcast(), 0); let runtime_call_expr = &mut ctx.heap[runtime_call_expr_id]; runtime_call_expr.parent = ExpressionParent::Expression(initial_expr_id.upcast(), 1); let initial_expr = &mut ctx.heap[initial_expr_id]; initial_expr.parent = ExpressionParent::Memory(variable_statement_id); return (variable_id, variable_statement_id); } } // ----------------------------------------------------------------------------- // Utilities to create compiler-generated AST nodes // ----------------------------------------------------------------------------- fn create_ast_variable(ctx: &mut Ctx) -> VariableId { return ctx.heap.alloc_variable(|this| Variable{ this, kind: VariableKind::Local, parser_type: ParserType{ elements: Vec::new(), full_span: InputSpan::new(), }, identifier: Identifier::new_empty(InputSpan::new()), relative_pos_in_parent: -1, unique_id_in_scope: -1, }); } fn create_ast_variable_expr(ctx: &mut Ctx, variable_id: VariableId) -> VariableExpressionId { return ctx.heap.alloc_variable_expression(|this| VariableExpression{ this, identifier: Identifier::new_empty(InputSpan::new()), declaration: Some(variable_id), used_as_binding_target: false, parent: ExpressionParent::None, unique_id_in_definition: -1 }); } fn create_ast_call_expr(ctx: &mut Ctx, method: Method, buffer: &mut ScopedBuffer, arguments: Vec) -> CallExpressionId { let expression_ids = buffer.start_section_initialized(&arguments); let call_expression_id = ctx.heap.alloc_call_expression(|this| CallExpression{ this, func_span: InputSpan::new(), full_span: InputSpan::new(), parser_type: ParserType{ elements: Vec::new(), full_span: InputSpan::new(), }, method, arguments, definition: DefinitionId::new_invalid(), parent: ExpressionParent::None, unique_id_in_definition: -1, }); for argument_index in 0..expression_ids.len() { let argument_id = expression_ids[argument_index]; let argument_expr = &mut ctx.heap[argument_id]; *argument_expr.parent_mut() = ExpressionParent::Expression(call_expression_id.upcast(), argument_index as u32); } return call_expression_id; } fn create_ast_literal_integer_expr(ctx: &mut Ctx, unsigned_value: u64) -> LiteralExpressionId { return ctx.heap.alloc_literal_expression(|this| LiteralExpression{ this, span: InputSpan::new(), value: Literal::Integer(LiteralInteger{ unsigned_value, negated: false, }), parent: ExpressionParent::None, unique_id_in_definition: -1 }); } fn create_ast_equality_comparison_expr(ctx: &mut Ctx, variable_id: VariableId, value: u64) -> BinaryExpressionId { let var_expr_id = create_ast_variable_expr(ctx, variable_id); let int_expr_id = create_ast_literal_integer_expr(ctx, value); let cmp_expr_id = ctx.heap.alloc_binary_expression(|this| BinaryExpression{ this, operator_span: InputSpan::new(), full_span: InputSpan::new(), left: var_expr_id.upcast(), operation: BinaryOperator::Equality, right: int_expr_id.upcast(), parent: ExpressionParent::None, unique_id_in_definition: -1, }); let var_expr = &mut ctx.heap[var_expr_id]; var_expr.parent = ExpressionParent::Expression(cmp_expr_id.upcast(), 0); let int_expr = &mut ctx.heap[int_expr_id]; int_expr.parent = ExpressionParent::Expression(cmp_expr_id.upcast(), 1); return cmp_expr_id; } fn create_ast_expression_stmt(ctx: &mut Ctx, expression_id: ExpressionId) -> ExpressionStatementId { let statement_id = ctx.heap.alloc_expression_statement(|this| ExpressionStatement{ this, span: InputSpan::new(), expression: expression_id, next: StatementId::new_invalid(), }); let expression = &mut ctx.heap[expression_id]; *expression.parent_mut() = ExpressionParent::ExpressionStmt(statement_id); return statement_id; } fn create_ast_variable_declaration_stmt(ctx: &mut Ctx, variable_id: VariableId, initial_value_expr_id: ExpressionId) -> MemoryStatementId { // Create the assignment expression, assigning the initial value to the variable let variable_expr_id = create_ast_variable_expr(ctx, variable_id); let assignment_expr_id = ctx.heap.alloc_assignment_expression(|this| AssignmentExpression{ this, operator_span: InputSpan::new(), full_span: InputSpan::new(), left: variable_expr_id.upcast(), operation: AssignmentOperator::Set, right: initial_value_expr_id, parent: ExpressionParent::None, unique_id_in_definition: -1, }); // Create the memory statement let memory_stmt_id = ctx.heap.alloc_memory_statement(|this| MemoryStatement{ this, span: InputSpan::new(), variable: variable_id, initial_expr: assignment_expr_id, next: StatementId::new_invalid(), }); // Set all parents which we can access let variable_expr = &mut ctx.heap[variable_expr_id]; variable_expr.parent = ExpressionParent::Expression(assignment_expr_id.upcast(), 0); let value_expr = &mut ctx.heap[initial_value_expr_id]; *value_expr.parent_mut() = ExpressionParent::Expression(assignment_expr_id.upcast(), 1); let assignment_expr = &mut ctx.heap[assignment_expr_id]; assignment_expr.parent = ExpressionParent::Memory(memory_stmt_id); return memory_stmt_id; } fn create_ast_if_stmt(ctx: &mut Ctx, condition_expression_id: ExpressionId, true_case: IfStatementCase, false_case: Option) -> (IfStatementId, EndIfStatementId) { // Create if statement and the end-if statement let if_stmt_id = ctx.heap.alloc_if_statement(|this| IfStatement{ this, span: InputSpan::new(), test: condition_expression_id, true_case, false_case, end_if: EndIfStatementId::new_invalid() }); let end_if_stmt_id = ctx.heap.alloc_end_if_statement(|this| EndIfStatement{ this, start_if: if_stmt_id, next: StatementId::new_invalid(), }); // Link the statements up as much as we can let if_stmt = &mut ctx.heap[if_stmt_id]; if_stmt.end_if = end_if_stmt_id; let condition_expr = &mut ctx.heap[condition_expression_id]; *condition_expr.parent_mut() = ExpressionParent::If(if_stmt_id); return (if_stmt_id, end_if_stmt_id); } fn set_ast_if_statement_false_body(ctx: &mut Ctx, if_statement_id: IfStatementId, end_if_statement_id: EndIfStatementId, false_body_id: StatementId) { // Point if-statement to "false body" todo!("set scopes"); let if_stmt = &mut ctx.heap[if_statement_id]; debug_assert!(if_stmt.false_case.is_none()); // simplifies logic, not necessary if_stmt.false_case = Some(IfStatementCase{ body: false_body_id, scope: ScopeId::new_invalid(), }); // Point end of false body to the end of the if statement set_ast_statement_next(ctx, false_body_id, end_if_statement_id.upcast()); } /// Sets the specified AST statement's control flow such that it will be /// followed by the target statement. This may seem obvious, but may imply that /// a statement associated with, but different from, the source statement is /// modified. fn set_ast_statement_next(ctx: &mut Ctx, source_stmt_id: StatementId, target_stmt_id: StatementId) { let source_stmt = &mut ctx.heap[source_stmt_id]; match source_stmt { Statement::Block(stmt) => { let end_id = stmt.end_block; ctx.heap[end_id].next = target_stmt_id }, Statement::EndBlock(stmt) => stmt.next = target_stmt_id, Statement::Local(stmt) => { match stmt { LocalStatement::Memory(stmt) => stmt.next = target_stmt_id, LocalStatement::Channel(stmt) => stmt.next = target_stmt_id, } }, Statement::Labeled(stmt) => { let body_id = stmt.body; set_ast_statement_next(ctx, body_id, target_stmt_id); }, Statement::If(stmt) => { let end_id = stmt.end_if; ctx.heap[end_id].next = target_stmt_id; }, Statement::EndIf(stmt) => stmt.next = target_stmt_id, Statement::While(stmt) => { let end_id = stmt.end_while; ctx.heap[end_id].next = target_stmt_id; }, Statement::EndWhile(stmt) => stmt.next = target_stmt_id, Statement::Break(_stmt) => {}, Statement::Continue(_stmt) => {}, Statement::Synchronous(stmt) => { let end_id = stmt.end_sync; ctx.heap[end_id].next = target_stmt_id; }, Statement::EndSynchronous(stmt) => { stmt.next = target_stmt_id; }, Statement::Fork(_) | Statement::EndFork(_) => { todo!("remove fork from language"); }, Statement::Select(stmt) => { let end_id = stmt.end_select; ctx.heap[end_id].next = target_stmt_id; }, Statement::EndSelect(stmt) => stmt.next = target_stmt_id, Statement::Return(_stmt) => {}, Statement::Goto(_stmt) => {}, Statement::New(stmt) => stmt.next = target_stmt_id, Statement::Expression(stmt) => stmt.next = target_stmt_id, } }