From a2bfd792a20295b95f43b776b2c4667dee31856b 2021-11-13 12:36:41 From: MH Date: 2021-11-13 12:36:41 Subject: [PATCH] Implement and test explicit forking in runtime --- diff --git a/src/protocol/ast_printer.rs b/src/protocol/ast_printer.rs index 4e566cdc23d634459b11530513dee8b50d57f704..6aff2810101c2a5143012ad83ec80eaa21cd852d 100644 --- a/src/protocol/ast_printer.rs +++ b/src/protocol/ast_printer.rs @@ -36,6 +36,8 @@ const PREFIX_BREAK_STMT_ID: &'static str = "SBre"; const PREFIX_CONTINUE_STMT_ID: &'static str = "SCon"; const PREFIX_SYNC_STMT_ID: &'static str = "SSyn"; const PREFIX_ENDSYNC_STMT_ID: &'static str = "SESy"; +const PREFIX_FORK_STMT_ID: &'static str = "SFrk"; +const PREFIX_END_FORK_STMT_ID: &'static str = "SEFk"; const PREFIX_RETURN_STMT_ID: &'static str = "SRet"; const PREFIX_ASSERT_STMT_ID: &'static str = "SAsr"; const PREFIX_GOTO_STMT_ID: &'static str = "SGot"; @@ -511,6 +513,24 @@ impl ASTWriter { self.kv(indent2).with_s_key("StartSync").with_disp_val(&stmt.start_sync.0.index); self.kv(indent2).with_s_key("Next").with_disp_val(&stmt.next.index); }, + Statement::Fork(stmt) => { + self.kv(indent).with_id(PREFIX_FORK_STMT_ID, stmt.this.0.index) + .with_s_key("Fork"); + self.kv(indent2).with_s_key("EndFork").with_disp_val(&stmt.end_fork.0.index); + self.kv(indent2).with_s_key("LeftBody"); + self.write_stmt(heap, stmt.left_body.upcast(), indent3); + + if let Some(right_body_id) = stmt.right_body { + self.kv(indent2).with_s_key("RightBody"); + self.write_stmt(heap, right_body_id.upcast(), indent3); + } + }, + Statement::EndFork(stmt) => { + self.kv(indent).with_id(PREFIX_END_FORK_STMT_ID, stmt.this.0.index) + .with_s_key("EndFork"); + self.kv(indent2).with_s_key("StartFork").with_disp_val(&stmt.start_fork.0.index); + self.kv(indent2).with_s_key("Next").with_disp_val(&stmt.next.index); + }, Statement::Return(stmt) => { self.kv(indent).with_id(PREFIX_RETURN_STMT_ID, stmt.this.0.index) .with_s_key("Return"); diff --git a/src/runtime2/branch.rs
b/src/runtime2/branch.rs index f3fa84c5ee325768f51d0ca69ce14f834f95e02e..a790ff23e52fc9ae7f37e70509474ac156234248 100644 --- a/src/runtime2/branch.rs +++ b/src/runtime2/branch.rs @@ -65,6 +65,7 @@ pub(crate) struct Branch { pub next_in_queue: BranchId, // used by `ExecTree`/`BranchQueue` pub inbox: HashMap, // TODO: Remove, currently only valid in single-get/put mode pub prepared_channel: Option<(Value, Value)>, // TODO: Maybe remove? + pub prepared_fork: Option<bool>, // TODO: See above } impl BranchListItem for Branch { @@ -85,17 +86,19 @@ impl Branch { next_in_queue: BranchId::new_invalid(), inbox: HashMap::new(), prepared_channel: None, + prepared_fork: None, } } /// Constructs a sync branch. The provided branch is assumed to be the /// parent of the new branch within the execution tree. fn new_sync(new_index: u32, parent_branch: &Branch) -> Self { - debug_assert!( - (parent_branch.sync_state == SpeculativeState::RunningNonSync && !parent_branch.parent_id.is_valid()) || - (parent_branch.sync_state == SpeculativeState::HaltedAtBranchPoint) - ); // forking from non-sync, or forking from a branching point + // debug_assert!( + // (parent_branch.sync_state == SpeculativeState::RunningNonSync && !parent_branch.parent_id.is_valid()) || + // (parent_branch.sync_state == SpeculativeState::HaltedAtBranchPoint) + // ); // forking from non-sync, or forking from a branching point. TODO(review): disabled by this patch, presumably because `RunResult::BranchFork` forks a branch in neither state; restore an updated precondition instead of keeping dead code debug_assert!(parent_branch.prepared_channel.is_none()); + debug_assert!(parent_branch.prepared_fork.is_none()); Branch { id: BranchId::new(new_index), @@ -106,6 +109,7 @@ impl Branch { next_in_queue: BranchId::new_invalid(), inbox: parent_branch.inbox.clone(), prepared_channel: None, + prepared_fork: None, } } diff --git a/src/runtime2/connector.rs b/src/runtime2/connector.rs index 5cd7bd6061b890d86093fffdebfb7e5975a95a9f..4f7c23baab3a6d06e4d7f391d0574ab5b733da77 100644 --- a/src/runtime2/connector.rs +++ b/src/runtime2/connector.rs @@ -75,6 +75,7 @@ struct ConnectorRunContext<'a> { received: &'a
HashMap, scheduler: SchedulerCtx<'a>, prepared_channel: Option<(Value, Value)>, + prepared_fork: Option<bool>, } impl<'a> RunContext for ConnectorRunContext<'a>{ @@ -101,6 +102,10 @@ impl<'a> RunContext for ConnectorRunContext<'a>{ fn get_channel(&mut self) -> Option<(Value, Value)> { return self.prepared_channel.take(); } + + fn get_fork(&mut self) -> Option<bool> { + return self.prepared_fork.take(); + } } impl Connector for ConnectorPDL { @@ -210,6 +215,7 @@ impl ConnectorPDL { received: &branch.inbox, scheduler: sched_ctx, prepared_channel: branch.prepared_channel.take(), + prepared_fork: branch.prepared_fork.take(), }; let run_result = branch.code_state.run(&mut run_context, &sched_ctx.runtime.protocol_description); @@ -293,6 +299,21 @@ impl ConnectorPDL { branch.sync_state = SpeculativeState::Inconsistent; } }, + RunResult::BranchFork => { + // Handled like the `NewChannel` result: fork this branch in two and + // store a marker on each branch (`prepared_fork`: left = true, right = + // false) that the PDL code receives via `get_fork()` on its next run + let left_id = branch_id; + let right_id = self.tree.fork_branch(left_id); + self.consensus.notify_of_new_branch(left_id, right_id); + self.tree.push_into_queue(QueueKind::Runnable, left_id); + self.tree.push_into_queue(QueueKind::Runnable, right_id); + + let left_branch = &mut self.tree[left_id]; + left_branch.prepared_fork = Some(true); + let right_branch = &mut self.tree[right_id]; + right_branch.prepared_fork = Some(false); + }, RunResult::BranchPut(port_id, content) => { // Branch is attempting to send data let port_id = PortIdLocal::new(port_id.0.u32_suffix); @@ -335,6 +356,7 @@ impl ConnectorPDL { received: &branch.inbox, scheduler: sched_ctx, prepared_channel: branch.prepared_channel.take(), + prepared_fork: branch.prepared_fork.take(), }; let run_result = branch.code_state.run(&mut run_context, &sched_ctx.runtime.protocol_description); diff --git a/src/runtime2/tests/api_component.rs b/src/runtime2/tests/api_component.rs index
28f7101520dec6d0d1c9075645339ecec5f161a5..67271d2986bb610f91a18425edff20de76b0b0f2 100644 --- a/src/runtime2/tests/api_component.rs +++ b/src/runtime2/tests/api_component.rs @@ -121,4 +121,38 @@ fn test_putting_to_component() { // Note: if we finish a round, then it must have succeeded :) api.wait().expect("finish sync round"); } +} + +#[test] +fn test_doing_nothing() { + const CODE: &'static str = " + primitive getter(in input, u32 num_loops) { + u32 index = 0; + while (index < num_loops) { + sync {} + sync { auto res = get(input); assert(res); } + index += 1; + } + } + "; + + let pd = ProtocolDescription::parse(CODE.as_bytes()).unwrap(); + let rt = Runtime::new(NUM_THREADS, pd); + let mut api = rt.create_interface(); + + let channel = api.create_channel().unwrap(); + api.create_connector("", "getter", ValueGroup::new_stack(vec![ + Value::Input(PortId::new(channel.getter_id.index)), + Value::UInt32(NUM_LOOPS), + ])).unwrap(); + + for _ in 0..NUM_LOOPS { + api.perform_sync_round(vec![]).expect("start silent sync round"); + api.wait().expect("finish silent sync round"); + api.perform_sync_round(vec![ + ApplicationSyncAction::Put(channel.putter_id, ValueGroup::new_stack(vec![Value::Bool(true)])) + ]).expect("start firing sync round"); + let res = api.wait().expect("finish firing sync round"); + assert!(res.is_empty()); + } } \ No newline at end of file diff --git a/src/runtime2/tests/mod.rs b/src/runtime2/tests/mod.rs index ccf488b87273ce44e6efc2e59e87dfbc47f26e4c..f31961a21b2eac002190fb3c8c9bb557a937082c 100644 --- a/src/runtime2/tests/mod.rs +++ b/src/runtime2/tests/mod.rs @@ -1,5 +1,6 @@ mod network_shapes; mod api_component; +mod speculation_basic; use super::*; use crate::{PortId, ProtocolDescription}; diff --git a/src/runtime2/tests/speculation_basic.rs b/src/runtime2/tests/speculation_basic.rs new file mode 100644 index 0000000000000000000000000000000000000000..c8522eb56d65352763b2e2f4f57dc142517fde7e --- /dev/null +++ b/src/runtime2/tests/speculation_basic.rs
@@ -0,0 +1,69 @@ +// Testing speculation - Basic forms + +use super::*; + +#[test] +fn test_maybe_do_nothing() { + // Two variants in which one component may choose to do nothing while its + // peer always fires, so doing nothing is not allowed. Note that we "check" by seeing if the test finishes. + // Only the branches in which ports fire increment the loop index + const CODE: &'static str = " + primitive only_puts(out output, u32 num_loops) { + u32 index = 0; + while (index < num_loops) { + sync { put(output, true); } + index += 1; + } + } + + primitive might_put(out output, u32 num_loops) { + u32 index = 0; + while (index < num_loops) { + sync { + fork { put(output, true); index += 1; } + or {} + } + } + } + + primitive only_gets(in input, u32 num_loops) { + u32 index = 0; + while (index < num_loops) { + sync { auto res = get(input); assert(res); } + index += 1; + } + } + + primitive might_get(in input, u32 num_loops) { + u32 index = 0; + while (index < num_loops) { + sync fork { auto res = get(input); assert(res); index += 1; } or {} + } + } + "; + + // Construct all variants which should work and wait until the runtime exits + run_test_in_runtime(CODE, |api| { + // only putting -> maybe getting + let channel = api.create_channel().unwrap(); + api.create_connector("", "only_puts", ValueGroup::new_stack(vec![ + Value::Output(PortId::new(channel.putter_id.index)), + Value::UInt32(NUM_LOOPS), + ])).unwrap(); + api.create_connector("", "might_get", ValueGroup::new_stack(vec![ + Value::Input(PortId::new(channel.getter_id.index)), + Value::UInt32(NUM_LOOPS), + ])).unwrap(); + + // maybe putting -> only getting + let channel = api.create_channel().unwrap(); + api.create_connector("", "might_put", ValueGroup::new_stack(vec![ + Value::Output(PortId::new(channel.putter_id.index)), + Value::UInt32(NUM_LOOPS), + ])).unwrap(); + api.create_connector("", "only_gets", ValueGroup::new_stack(vec![ + Value::Input(PortId::new(channel.getter_id.index)), + Value::UInt32(NUM_LOOPS), + ])).unwrap(); + }) +} \ No newline at end of file