diff --git a/src/protocol/parser/pass_rewriting.rs b/src/protocol/parser/pass_rewriting.rs index 517de373802e3d79c4936c3afa27b312898e606b..80554d91b426312cf0301354c1b037d4f10e5f09 100644 --- a/src/protocol/parser/pass_rewriting.rs +++ b/src/protocol/parser/pass_rewriting.rs @@ -59,12 +59,12 @@ impl Visitor for PassRewriting { fn visit_if_stmt(&mut self, ctx: &mut Ctx, id: IfStatementId) -> VisitorResult { let if_stmt = &ctx.heap[id]; - let true_body_id = if_stmt.true_body; - let false_body_id = if_stmt.false_body; + let true_body_id = if_stmt.true_case; + let false_body_id = if_stmt.false_case; - self.visit_block_stmt(ctx, true_body_id)?; + self.visit_stmt(ctx, true_body_id.body)?; if let Some(false_body_id) = false_body_id { - self.visit_block_stmt(ctx, false_body_id)?; + self.visit_stmt(ctx, false_body_id.body)?; } return Ok(()) @@ -73,37 +73,38 @@ impl Visitor for PassRewriting { fn visit_while_stmt(&mut self, ctx: &mut Ctx, id: WhileStatementId) -> VisitorResult { let while_stmt = &ctx.heap[id]; let body_id = while_stmt.body; - return self.visit_block_stmt(ctx, body_id); + return self.visit_stmt(ctx, body_id); } fn visit_synchronous_stmt(&mut self, ctx: &mut Ctx, id: SynchronousStatementId) -> VisitorResult { let sync_stmt = &ctx.heap[id]; let body_id = sync_stmt.body; - return self.visit_block_stmt(ctx, body_id); + return self.visit_stmt(ctx, body_id); } // --- Visiting the select statement fn visit_select_stmt(&mut self, ctx: &mut Ctx, id: SelectStatementId) -> VisitorResult { - // Utility for the last stage of rewriting process + // Utility for the last stage of rewriting process. Note that caller + // still needs to point the end of the if-statement to the end of the + // replacement statement of the select statement. fn transform_select_case_code(ctx: &mut Ctx, select_id: SelectStatementId, case_index: usize, select_var_id: VariableId) -> (IfStatementId, EndIfStatementId) { // Retrieve statement IDs associated with case let case = &ctx.heap[select_id].cases[case_index]; let case_guard_id = case.guard; - let case_body_id = case.block; + let case_body_id = case.body; + let case_scope_id = case.scope; // Create the if-statement for the result of the select statement let compare_expr_id = create_ast_equality_comparison_expr(ctx, select_var_id, case_index as u64); - let (if_stmt_id, end_if_stmt_id) = create_ast_if_stmt(ctx, compare_expr_id.upcast(), case_body_id, None); - - // Modify body of case to link up to the surrounding statements - // correctly - let case_body = &mut ctx.heap[case_body_id]; - let case_end_body_id = case_body.end_block; - case_body.statements.insert(0, case_guard_id); + let true_case = IfStatementCase{ + body: case_guard_id, // which is linked up to the body + scope: case_scope_id, + }; + let (if_stmt_id, end_if_stmt_id) = create_ast_if_stmt(ctx, compare_expr_id.upcast(), true_case, None); - let case_end_body = &mut ctx.heap[case_end_body_id]; - case_end_body.next = end_if_stmt_id.upcast(); + // Link up body statement to end-if + set_ast_statement_next(ctx, case_body_id, end_if_stmt_id.upcast()); return (if_stmt_id, end_if_stmt_id) } @@ -112,7 +113,7 @@ impl Visitor for PassRewriting { // containing builtin runtime-calls. And to do so we create temporary // variables and move some other statements around. let select_stmt = &ctx.heap[id]; - let mut total_num_cases = select_stmt.cases.len(); + let total_num_cases = select_stmt.cases.len(); let mut total_num_ports = 0; let end_select_stmt_id = select_stmt.end_select; let end_select = &ctx.heap[end_select_stmt_id]; @@ -165,7 +166,7 @@ impl Visitor for PassRewriting { num_ports_expression_id.upcast() ]; - let call_expression_id = create_ast_call_expr(ctx, Method::SelectStart, arguments); + let call_expression_id = create_ast_call_expr(ctx, Method::SelectStart, &mut self.expression_buffer, arguments); let call_statement_id = create_ast_expression_stmt(ctx, call_expression_id.upcast()); transformed_stmts.push(call_statement_id.upcast()); @@ -192,7 +193,7 @@ impl Visitor for PassRewriting { ]; // Create runtime call, then store it - let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectRegisterCasePort, runtime_call_arguments); + let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectRegisterCasePort, &mut self.expression_buffer, runtime_call_arguments); let runtime_call_stmt_id = create_ast_expression_stmt(ctx, runtime_call_expr_id.upcast()); transformed_stmts.push(runtime_call_stmt_id.upcast()); @@ -208,7 +209,7 @@ impl Visitor for PassRewriting { locals.push(select_variable_id); { - let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectWait, Vec::new()); + let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectWait, &mut self.expression_buffer, Vec::new()); let variable_stmt_id = create_ast_variable_declaration_stmt(ctx, select_variable_id, runtime_call_expr_id.upcast()); transformed_stmts.push(variable_stmt_id.upcast().upcast()); } @@ -224,11 +225,7 @@ impl Visitor for PassRewriting { span: InputSpan::new(), statements: Vec::new(), end_block: EndBlockStatementId::new_invalid(), - scope_node: ScopeNode::new_invalid(), - first_unique_id_in_scope: -1, - next_unique_id_in_scope: -1, - locals: Vec::new(), - labels: Vec::new(), + scope: ScopeId::new_invalid(), next: StatementId::new_invalid(), }); let end_block_id = ctx.heap.alloc_end_block_statement(|this| EndBlockStatement{ @@ -248,7 +245,7 @@ impl Visitor for PassRewriting { for case_index in 1..total_num_cases { let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, id, case_index, select_variable_id); let last_if_stmt = &mut ctx.heap[last_if_stmt_id]; - last_if_stmt.false_body = Some(if_stmt_id.upcast()); + // last_if_stmt.false_case = Some(if_stmt_id.upcast()); // TODO: // 1. Change scoping such that it is a separate datastructure with separate IDs @@ -365,7 +362,7 @@ fn create_ast_variable(ctx: &mut Ctx) -> VariableId { full_span: InputSpan::new(), }, identifier: Identifier::new_empty(InputSpan::new()), - relative_pos_in_block: -1, + relative_pos_in_parent: -1, unique_id_in_scope: -1, }); } @@ -381,7 +378,8 @@ fn create_ast_variable_expr(ctx: &mut Ctx, variable_id: VariableId) -> VariableE }); } -fn create_ast_call_expr(ctx: &mut Ctx, method: Method, arguments: Vec) -> CallExpressionId { +fn create_ast_call_expr(ctx: &mut Ctx, method: Method, buffer: &mut ScopedBuffer, arguments: Vec) -> CallExpressionId { + let expression_ids = buffer.start_section_initialized(&arguments); let call_expression_id = ctx.heap.alloc_call_expression(|this| CallExpression{ this, func_span: InputSpan::new(), @@ -397,9 +395,10 @@ fn create_ast_call_expr(ctx: &mut Ctx, method: Method, arguments: Vec) -> (IfStatementId, EndIfStatementId) { +fn create_ast_if_stmt(ctx: &mut Ctx, condition_expression_id: ExpressionId, true_case: IfStatementCase, false_case: Option) -> (IfStatementId, EndIfStatementId) { // Create if statement and the end-if statement let if_stmt_id = ctx.heap.alloc_if_statement(|this| IfStatement{ this, span: InputSpan::new(), test: condition_expression_id, - true_body, - false_body, + true_case, + false_case, end_if: EndIfStatementId::new_invalid() }); @@ -512,27 +511,76 @@ fn create_ast_if_stmt(ctx: &mut Ctx, condition_expression_id: ExpressionId, true let condition_expr = &mut ctx.heap[condition_expression_id]; *condition_expr.parent_mut() = ExpressionParent::If(if_stmt_id); - let true_body_stmt = &ctx.heap[true_body]; - let true_body_end_stmt = &mut ctx.heap[true_body_stmt.end_block]; - true_body_end_stmt.next = end_if_stmt_id.upcast(); - - if let Some(false_body) = false_body { - let false_body_stmt = &ctx.heap[false_body]; - let false_body_end_stmt = &mut ctx.heap[false_body_stmt.end_block]; - false_body_end_stmt.next = end_if_stmt_id.upcast(); - } - return (if_stmt_id, end_if_stmt_id); } -fn set_ast_if_statement_false_body(ctx: &mut Ctx, if_statement_id: IfStatementId, end_if_statement_id: EndIfStatementId, false_body: BlockStatementId) { +fn set_ast_if_statement_false_body(ctx: &mut Ctx, if_statement_id: IfStatementId, end_if_statement_id: EndIfStatementId, false_body_id: StatementId) { // Point if-statement to "false body" + todo!("set scopes"); let if_stmt = &mut ctx.heap[if_statement_id]; - debug_assert!(if_stmt.false_body.is_none()); // simplifies logic, not necessary - if_stmt.false_body = Some(false_body); + debug_assert!(if_stmt.false_case.is_none()); // simplifies logic, not necessary + if_stmt.false_case = Some(IfStatementCase{ + body: false_body_id, + scope: ScopeId::new_invalid(), + }); // Point end of false body to the end of the if statement - let false_body_stmt = &ctx.heap[false_body]; - let false_body_end_stmt = &mut ctx.heap[false_body_stmt.end_block]; - false_body_end_stmt.next = end_if_statement_id.upcast(); + set_ast_statement_next(ctx, false_body_id, end_if_statement_id.upcast()); +} + +/// Sets the specified AST statement's control flow such that it will be +/// followed by the target statement. This may seem obvious, but may imply that +/// a statement associated with, but different from, the source statement is +/// modified. +fn set_ast_statement_next(ctx: &mut Ctx, source_stmt_id: StatementId, target_stmt_id: StatementId) { + let source_stmt = &mut ctx.heap[source_stmt_id]; + match source_stmt { + Statement::Block(stmt) => { + let end_id = stmt.end_block; + ctx.heap[end_id].next = target_stmt_id + }, + Statement::EndBlock(stmt) => stmt.next = target_stmt_id, + Statement::Local(stmt) => { + match stmt { + LocalStatement::Memory(stmt) => stmt.next = target_stmt_id, + LocalStatement::Channel(stmt) => stmt.next = target_stmt_id, + } + }, + Statement::Labeled(stmt) => { + let body_id = stmt.body; + set_ast_statement_next(ctx, body_id, target_stmt_id); + }, + Statement::If(stmt) => { + let end_id = stmt.end_if; + ctx.heap[end_id].next = target_stmt_id; + }, + Statement::EndIf(stmt) => stmt.next = target_stmt_id, + Statement::While(stmt) => { + let end_id = stmt.end_while; + ctx.heap[end_id].next = target_stmt_id; + }, + Statement::EndWhile(stmt) => stmt.next = target_stmt_id, + + Statement::Break(_stmt) => {}, + Statement::Continue(_stmt) => {}, + Statement::Synchronous(stmt) => { + let end_id = stmt.end_sync; + ctx.heap[end_id].next = target_stmt_id; + }, + Statement::EndSynchronous(stmt) => { + stmt.next = target_stmt_id; + }, + Statement::Fork(_) | Statement::EndFork(_) => { + todo!("remove fork from language"); + }, + Statement::Select(stmt) => { + let end_id = stmt.end_select; + ctx.heap[end_id].next = target_stmt_id; + }, + Statement::EndSelect(stmt) => stmt.next = target_stmt_id, + Statement::Return(_stmt) => {}, + Statement::Goto(_stmt) => {}, + Statement::New(stmt) => stmt.next = target_stmt_id, + Statement::Expression(stmt) => stmt.next = target_stmt_id, + } } \ No newline at end of file