Changeset - a2696db4bb8e
[Not reviewed]
0 1 0
MH - 3 years ago 2022-02-25 12:48:37
contact@maxhenger.nl
Add typing to AST transformation
1 file changed with 90 insertions and 31 deletions:
0 comments (0 inline, 0 general)
src/protocol/parser/pass_rewriting.rs
Show inline comments
 
@@ -5,6 +5,7 @@ use super::visitor::*;
 

	
 
pub(crate) struct PassRewriting {
 
    current_scope: ScopeId,
 
    current_procedure_id: ProcedureDefinitionId,
 
    definition_buffer: ScopedBuffer<DefinitionId>,
 
    statement_buffer: ScopedBuffer<StatementId>,
 
    call_expr_buffer: ScopedBuffer<CallExpressionId>,
 
@@ -16,6 +17,7 @@ impl PassRewriting {
 
    pub(crate) fn new() -> Self {
 
        Self{
 
            current_scope: ScopeId::new_invalid(),
 
            current_procedure_id: ProcedureDefinitionId::new_invalid(),
 
            definition_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_LARGE),
 
            statement_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL),
 
            call_expr_buffer: ScopedBuffer::with_capacity(BUFFER_INIT_CAP_SMALL),
 
@@ -49,6 +51,7 @@ impl Visitor for PassRewriting {
 
        let definition = &ctx.heap[id];
 
        let body_id = definition.body;
 
        self.current_scope = definition.scope;
 
        self.current_procedure_id = id;
 
        return self.visit_block_stmt(ctx, body_id);
 
    }
 

	
 
@@ -108,7 +111,11 @@ impl Visitor for PassRewriting {
 
        // Utility for the last stage of rewriting process. Note that caller
 
        // still needs to point the end of the if-statement to the end of the
 
        // replacement statement of the select statement.
 
        fn transform_select_case_code(ctx: &mut Ctx, select_id: SelectStatementId, case_index: usize, select_var_id: VariableId) -> (IfStatementId, EndIfStatementId) {
 
        fn transform_select_case_code(
 
            ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId,
 
            select_id: SelectStatementId, case_index: usize,
 
            select_var_id: VariableId, select_var_type_id: TypeIdReference
 
        ) -> (IfStatementId, EndIfStatementId) {
 
            // Retrieve statement IDs associated with case
 
            let case = &ctx.heap[select_id].cases[case_index];
 
            let case_guard_id = case.guard;
 
@@ -116,7 +123,7 @@ impl Visitor for PassRewriting {
 
            let case_scope_id = case.scope;
 

	
 
            // Create the if-statement for the result of the select statement
 
            let compare_expr_id = create_ast_equality_comparison_expr(ctx, select_var_id, case_index as u64);
 
            let compare_expr_id = create_ast_equality_comparison_expr(ctx, containing_procedure_id, select_var_id, select_var_type_id, case_index as u64);
 
            let true_case = IfStatementCase{
 
                body: case_guard_id, // which is linked up to the body
 
                scope: case_scope_id,
 
@@ -177,19 +184,21 @@ impl Visitor for PassRewriting {
 
        for port_var_idx in 0..call_id_section.len() {
 
            let get_call_expr_id = call_id_section[port_var_idx];
 
            let port_expr_id = expr_id_section[port_var_idx];
 
            let port_type_index = ctx.heap[port_expr_id].type_index();
 
            let port_type_ref = TypeIdReference::IndirectSameAsExpr(port_type_index);
 

	
 
            // Move the port expression such that it gets assigned to a temporary variable
 
            let variable_id = create_ast_variable(ctx, outer_scope_id);
 
            let variable_decl_stmt_id = create_ast_variable_declaration_stmt(ctx, variable_id, port_expr_id);
 
            let variable_decl_stmt_id = create_ast_variable_declaration_stmt(ctx, self.current_procedure_id, variable_id, port_type_ref, port_expr_id);
 

	
 
            // Replace the original port expression in the call with a reference
 
            // to the replacement variable
 
            let variable_expr_id = create_ast_variable_expr(ctx, variable_id);
 
            let variable_expr_id = create_ast_variable_expr(ctx, self.current_procedure_id, variable_id, port_type_ref);
 
            let call_expr = &mut ctx.heap[get_call_expr_id];
 
            call_expr.arguments[0] = variable_expr_id.upcast();
 

	
 
            transformed_stmts.push(variable_decl_stmt_id.upcast().upcast());
 
            locals.push(variable_id);
 
            locals.push((variable_id, port_type_ref));
 
        }
 

	
 
        // Insert runtime calls that facilitate the semantics of the select
 
@@ -197,14 +206,14 @@ impl Visitor for PassRewriting {
 

	
 
        // Create the call that indicates the start of the select block
 
        {
 
            let num_cases_expression_id = create_ast_literal_integer_expr(ctx, total_num_cases as u64);
 
            let num_ports_expression_id = create_ast_literal_integer_expr(ctx, total_num_ports as u64);
 
            let num_cases_expression_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, total_num_cases as u64);
 
            let num_ports_expression_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, total_num_ports as u64);
 
            let arguments = vec![
 
                num_cases_expression_id.upcast(),
 
                num_ports_expression_id.upcast()
 
            ];
 

	
 
            let call_expression_id = create_ast_call_expr(ctx, Method::SelectStart, &mut self.expression_buffer, arguments);
 
            let call_expression_id = create_ast_call_expr(ctx, self.current_procedure_id, Method::SelectStart, &mut self.expression_buffer, arguments);
 
            let call_statement_id = create_ast_expression_stmt(ctx, call_expression_id.upcast());
 

	
 
            transformed_stmts.push(call_statement_id.upcast());
 
@@ -220,10 +229,10 @@ impl Visitor for PassRewriting {
 

	
 
                for case_port_index in 0..case_num_ports {
 
                    // Arguments to runtime call
 
                    let port_variable_id = locals[total_port_index]; // so far this variable contains the temporary variables for the port expressions
 
                    let case_index_expr_id = create_ast_literal_integer_expr(ctx, case_index as u64);
 
                    let port_index_expr_id = create_ast_literal_integer_expr(ctx, case_port_index as u64);
 
                    let port_variable_expr_id = create_ast_variable_expr(ctx, port_variable_id);
 
                    let (port_variable_id, port_variable_type) = locals[total_port_index]; // so far this variable contains the temporary variables for the port expressions
 
                    let case_index_expr_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, case_index as u64);
 
                    let port_index_expr_id = create_ast_literal_integer_expr(ctx, self.current_procedure_id, case_port_index as u64);
 
                    let port_variable_expr_id = create_ast_variable_expr(ctx, self.current_procedure_id, port_variable_id, port_variable_type);
 
                    let runtime_call_arguments = vec![
 
                        case_index_expr_id.upcast(),
 
                        port_index_expr_id.upcast(),
 
@@ -231,7 +240,7 @@ impl Visitor for PassRewriting {
 
                    ];
 

	
 
                    // Create runtime call, then store it
 
                    let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectRegisterCasePort, &mut self.expression_buffer, runtime_call_arguments);
 
                    let runtime_call_expr_id = create_ast_call_expr(ctx, self.current_procedure_id, Method::SelectRegisterCasePort, &mut self.expression_buffer, runtime_call_arguments);
 
                    let runtime_call_stmt_id = create_ast_expression_stmt(ctx, runtime_call_expr_id.upcast());
 

	
 
                    transformed_stmts.push(runtime_call_stmt_id.upcast());
 
@@ -244,11 +253,12 @@ impl Visitor for PassRewriting {
 
        // Create the variable that will hold the result of a completed select
 
        // block. Then create the runtime call that will produce this result
 
        let select_variable_id = create_ast_variable(ctx, outer_scope_id);
 
        locals.push(select_variable_id);
 
        let select_variable_type = TypeIdReference::DirectTypeId(ctx.arch.uint32_type_id);
 
        locals.push((select_variable_id, select_variable_type));
 

	
 
        {
 
            let runtime_call_expr_id = create_ast_call_expr(ctx, Method::SelectWait, &mut self.expression_buffer, Vec::new());
 
            let variable_stmt_id = create_ast_variable_declaration_stmt(ctx, select_variable_id, runtime_call_expr_id.upcast());
 
            let runtime_call_expr_id = create_ast_call_expr(ctx, self.current_procedure_id, Method::SelectWait, &mut self.expression_buffer, Vec::new());
 
            let variable_stmt_id = create_ast_variable_declaration_stmt(ctx, self.current_procedure_id, select_variable_id, select_variable_type, runtime_call_expr_id.upcast());
 
            transformed_stmts.push(variable_stmt_id.upcast().upcast());
 
        }
 

	
 
@@ -258,7 +268,7 @@ impl Visitor for PassRewriting {
 
        // Now we transform each of the select block case's guard and code into
 
        // a chained if-else statement.
 
        if total_num_cases > 0 {
 
            let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, id, 0, select_variable_id);
 
            let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, self.current_procedure_id, id, 0, select_variable_id, select_variable_type);
 
            let first_end_if_stmt = &mut ctx.heap[end_if_stmt_id];
 
            first_end_if_stmt.next = outer_end_block_id.upcast();
 

	
 
@@ -267,7 +277,7 @@ impl Visitor for PassRewriting {
 
            transformed_stmts.push(last_if_stmt_id.upcast());
 

	
 
            for case_index in 1..total_num_cases {
 
                let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, id, case_index, select_variable_id);
 
                let (if_stmt_id, end_if_stmt_id) = transform_select_case_code(ctx, self.current_procedure_id, id, case_index, select_variable_id, select_variable_type);
 
                let false_case_scope_id = ctx.heap.alloc_scope(|this| Scope::new(this, ScopeAssociation::If(last_if_stmt_id, false)));
 
                set_ast_if_statement_false_body(ctx, last_if_stmt_id, last_end_if_stmt_id, IfStatementCase{ body: if_stmt_id.upcast(), scope: false_case_scope_id });
 

	
 
@@ -298,6 +308,12 @@ impl Visitor for PassRewriting {
 
// Utilities to create compiler-generated AST nodes
 
// -----------------------------------------------------------------------------
 

	
 
#[derive(Clone, Copy)]
 
enum TypeIdReference {
 
    DirectTypeId(TypeId),
 
    IndirectSameAsExpr(i32), // by type index
 
}
 

	
 
fn create_ast_variable(ctx: &mut Ctx, scope_id: ScopeId) -> VariableId {
 
    let variable_id = ctx.heap.alloc_variable(|this| Variable{
 
        this,
 
@@ -316,22 +332,31 @@ fn create_ast_variable(ctx: &mut Ctx, scope_id: ScopeId) -> VariableId {
 
    return variable_id;
 
}
 

	
 
fn create_ast_variable_expr(ctx: &mut Ctx, variable_id: VariableId) -> VariableExpressionId {
 
fn create_ast_variable_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, variable_id: VariableId, variable_type_id: TypeIdReference) -> VariableExpressionId {
 
    let variable_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, variable_type_id);
 
    return ctx.heap.alloc_variable_expression(|this| VariableExpression{
 
        this,
 
        identifier: Identifier::new_empty(InputSpan::new()),
 
        declaration: Some(variable_id),
 
        used_as_binding_target: false,
 
        parent: ExpressionParent::None,
 
        type_index: -1
 
        type_index: variable_type_index,
 
    });
 
}
 

	
 
fn create_ast_call_expr(ctx: &mut Ctx, method: Method, buffer: &mut ScopedBuffer<ExpressionId>, arguments: Vec<ExpressionId>) -> CallExpressionId {
 
fn create_ast_call_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, method: Method, buffer: &mut ScopedBuffer<ExpressionId>, arguments: Vec<ExpressionId>) -> CallExpressionId {
 
    let call_type_id = match method {
 
        Method::SelectStart => ctx.arch.void_type_id,
 
        Method::SelectRegisterCasePort => ctx.arch.void_type_id,
 
        Method::SelectWait => ctx.arch.uint32_type_id, // TODO: Not pretty, this. Pretty error prone
 
        _ => unreachable!(), // if this goes of, add the appropriate method here.
 
    };
 

	
 
    let expression_ids = buffer.start_section_initialized(&arguments);
 
    let call_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(call_type_id));
 
    let call_expression_id = ctx.heap.alloc_call_expression(|this| CallExpression{
 
        this,
 
        func_span: InputSpan::new(),
 
        this,
 
        full_span: InputSpan::new(),
 
        parser_type: ParserType{
 
            elements: Vec::new(),
 
@@ -341,7 +366,7 @@ fn create_ast_call_expr(ctx: &mut Ctx, method: Method, buffer: &mut ScopedBuffer
 
        arguments,
 
        procedure: ProcedureDefinitionId::new_invalid(),
 
        parent: ExpressionParent::None,
 
        type_index: -1,
 
        type_index: call_type_index,
 
    });
 

	
 
    for argument_index in 0..expression_ids.len() {
 
@@ -353,7 +378,8 @@ fn create_ast_call_expr(ctx: &mut Ctx, method: Method, buffer: &mut ScopedBuffer
 
    return call_expression_id;
 
}
 

	
 
fn create_ast_literal_integer_expr(ctx: &mut Ctx, unsigned_value: u64) -> LiteralExpressionId {
 
fn create_ast_literal_integer_expr(ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId, unsigned_value: u64) -> LiteralExpressionId {
 
    let literal_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(ctx.arch.uint64_type_id));
 
    return ctx.heap.alloc_literal_expression(|this| LiteralExpression{
 
        this,
 
        span: InputSpan::new(),
 
@@ -362,13 +388,17 @@ fn create_ast_literal_integer_expr(ctx: &mut Ctx, unsigned_value: u64) -> Litera
 
            negated: false,
 
        }),
 
        parent: ExpressionParent::None,
 
        type_index: -1
 
        type_index: literal_type_index,
 
    });
 
}
 

	
 
fn create_ast_equality_comparison_expr(ctx: &mut Ctx, variable_id: VariableId, value: u64) -> BinaryExpressionId {
 
    let var_expr_id = create_ast_variable_expr(ctx, variable_id);
 
    let int_expr_id = create_ast_literal_integer_expr(ctx, value);
 
fn create_ast_equality_comparison_expr(
 
    ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId,
 
    variable_id: VariableId, variable_type: TypeIdReference, value: u64
 
) -> BinaryExpressionId {
 
    let var_expr_id = create_ast_variable_expr(ctx, containing_procedure_id, variable_id, variable_type);
 
    let int_expr_id = create_ast_literal_integer_expr(ctx, containing_procedure_id, value);
 
    let cmp_type_index = add_new_procedure_expression_type(ctx, containing_procedure_id, TypeIdReference::DirectTypeId(ctx.arch.bool_type_id));
 
    let cmp_expr_id = ctx.heap.alloc_binary_expression(|this| BinaryExpression{
 
        this,
 
        operator_span: InputSpan::new(),
 
@@ -377,7 +407,7 @@ fn create_ast_equality_comparison_expr(ctx: &mut Ctx, variable_id: VariableId, v
 
        operation: BinaryOperator::Equality,
 
        right: int_expr_id.upcast(),
 
        parent: ExpressionParent::None,
 
        type_index: -1,
 
        type_index: cmp_type_index,
 
    });
 

	
 
    let var_expr = &mut ctx.heap[var_expr_id];
 
@@ -402,9 +432,12 @@ fn create_ast_expression_stmt(ctx: &mut Ctx, expression_id: ExpressionId) -> Exp
 
    return statement_id;
 
}
 

	
 
fn create_ast_variable_declaration_stmt(ctx: &mut Ctx, variable_id: VariableId, initial_value_expr_id: ExpressionId) -> MemoryStatementId {
 
fn create_ast_variable_declaration_stmt(
 
    ctx: &mut Ctx, containing_procedure_id: ProcedureDefinitionId,
 
    variable_id: VariableId, variable_type: TypeIdReference, initial_value_expr_id: ExpressionId
 
) -> MemoryStatementId {
 
    // Create the assignment expression, assigning the initial value to the variable
 
    let variable_expr_id = create_ast_variable_expr(ctx, variable_id);
 
    let variable_expr_id = create_ast_variable_expr(ctx, containing_procedure_id, variable_id, variable_type);
 
    let assignment_expr_id = ctx.heap.alloc_assignment_expression(|this| AssignmentExpression{
 
        this,
 
        operator_span: InputSpan::new(),
 
@@ -605,3 +638,29 @@ fn add_child_scope_to_parent(ctx: &mut Ctx, scope_buffer: &mut ScopedBuffer<Scop
 
    let parent_scope = &mut ctx.heap[parent_scope_id];
 
    parent_scope.nested.insert(insert_pos, child_scope_id);
 
}
 

	
 
fn add_new_procedure_expression_type(ctx: &mut Ctx, procedure_id: ProcedureDefinitionId, type_id: TypeIdReference) -> i32 {
 
    let procedure = &mut ctx.heap[procedure_id];
 
    let type_index = procedure.monomorphs[0].expr_info.len();
 

	
 
    match type_id {
 
        TypeIdReference::DirectTypeId(type_id) => {
 
            for monomorph in procedure.monomorphs.iter_mut() {
 
                debug_assert_eq!(monomorph.expr_info.len(), type_index);
 
                monomorph.expr_info.push(ExpressionInfo{
 
                    type_id,
 
                    variant: ExpressionInfoVariant::Generic
 
                });
 
            }
 
        },
 
        TypeIdReference::IndirectSameAsExpr(source_type_index) => {
 
            for monomorph in procedure.monomorphs.iter_mut() {
 
                debug_assert_eq!(monomorph.expr_info.len(), type_index);
 
                let copied_expr_info = monomorph.expr_info[source_type_index as usize];
 
                monomorph.expr_info.push(copied_expr_info)
 
            }
 
        }
 
    }
 

	
 
    return type_index as i32;
 
}
0 comments (0 inline, 0 general)