diff --git a/src/macros.rs b/src/macros.rs index 1d50f6571be4b313e50a9805c37d2cf393500373..e27cb8c2d16d10bef8c5165b96b8d6d091bde3fe 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -21,8 +21,9 @@ macro_rules! dbg_code { } // Given a function name, return type and variant, will generate the all-so -// common `union_value.as_variant()` method. -macro_rules! union_cast_method_impl { +// common `union_value.as_variant()` method. The return value is the reference +// to the embedded union type. +macro_rules! union_cast_to_ref_method_impl { ($func_name:ident, $ret_type:ty, $variant:path) => { fn $func_name(&self) -> &$ret_type { match self { @@ -31,4 +32,18 @@ macro_rules! union_cast_method_impl { } } } +} + +// Another union cast, but now returning a copy of the value +macro_rules! union_cast_to_value_method_impl { + ($func_name:ident, $ret_type:ty, $variant:path) => { + impl Value { + pub(crate) fn $func_name(&self) -> $ret_type { + match self { + $variant(v) => *v, + _ => unreachable!(), + } + } + } + } } \ No newline at end of file diff --git a/src/protocol/eval/executor.rs b/src/protocol/eval/executor.rs index 913be03fa88e408a2ac4bc3b71832f87bfbfaf8d..25782476114ff1097fca333d48ad04a7afb1c099 100644 --- a/src/protocol/eval/executor.rs +++ b/src/protocol/eval/executor.rs @@ -207,6 +207,9 @@ pub enum EvalContinuation { BlockFires(PortId), BlockGet(PortId), Put(PortId, ValueGroup), + SelectStart(u32, u32), // (num_cases, num_ports_total) + SelectRegisterPort(u32, u32, PortId), // (case_index, port_index_in_case, port_id) + SelectWait, // wait until select can continue // Returned only in non-sync mode ComponentTerminated, SyncBlockStart, @@ -652,12 +655,7 @@ impl Prompt { Method::Fires => { let port_value = cur_frame.expr_values.pop_front().unwrap(); let port_value_deref = self.store.maybe_read_ref(&port_value).clone(); - - let port_id = match port_value_deref { - Value::Input(port_id) => port_id, - Value::Output(port_id) => port_id, - _ => unreachable!("executor calling 'fires' on value {:?}", port_value_deref), - }; + let port_id = port_value_deref.as_port_id(); match ctx.fires(port_id) { None => { @@ -736,13 +734,28 @@ impl Prompt { println!("{}", message); }, Method::SelectStart => { - todo!("select start"); + let num_cases = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_uint32(); + let num_ports = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_uint32(); + if !ctx.select_start(num_cases, num_ports) { + return Ok(EvalContinuation::SelectStart(num_cases, num_ports)) + } }, Method::SelectRegisterCasePort => { - todo!("select register"); + let case_index = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_uint32(); + let port_index = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_uint32(); + let port_value = self.store.maybe_read_ref(&cur_frame.expr_values.pop_front().unwrap()).as_port_id(); + + if !ctx.performed_select_start() { + return Ok(EvalContinuation::SelectRegisterPort(case_index, port_index, port_value)); + } }, Method::SelectWait => { - todo!("select wait"); + match ctx.performed_select_wait() { + Some(select_index) => { + cur_frame.expr_values.push_back(Value::UInt32(select_index)); + }, + None => return Ok(EvalContinuation::SelectWait), + } }, Method::UserComponent => { // This is actually handled by the evaluation @@ -965,8 +978,12 @@ impl Prompt { Ok(EvalContinuation::Stepping) }, - Statement::Select(_stmt) => { - todo!("implement select evaluation") + Statement::Select(stmt) => { + // This is a trampoline for the statements that were placed by + // the AST transformation pass + cur_frame.position = stmt.next; + + Ok(EvalContinuation::Stepping) }, Statement::EndSelect(stmt) => { cur_frame.position = stmt.next; diff --git a/src/protocol/eval/value.rs b/src/protocol/eval/value.rs index 9d9d1736c3a51487dd77f3215782dd1149e0671b..d8bf773b7bc74426a37fa54ad573c4c0d6d8bd00 100644 --- a/src/protocol/eval/value.rs +++ b/src/protocol/eval/value.rs @@ -65,39 +65,26 @@ pub enum Value { Struct(HeapPos), } -macro_rules! impl_union_unpack_as_value { - ($func_name:ident, $variant_name:path, $return_type:ty) => { - impl Value { - pub(crate) fn $func_name(&self) -> $return_type { - match self { - $variant_name(v) => *v, - _ => panic!(concat!("called ", stringify!($func_name()), " on {:?}"), self), - } - } - } - } -} - -impl_union_unpack_as_value!(as_stack_boundary, Value::PrevStackBoundary, isize); -impl_union_unpack_as_value!(as_ref, Value::Ref, ValueId); -impl_union_unpack_as_value!(as_input, Value::Input, PortId); -impl_union_unpack_as_value!(as_output, Value::Output, PortId); -impl_union_unpack_as_value!(as_message, Value::Message, HeapPos); -impl_union_unpack_as_value!(as_bool, Value::Bool, bool); -impl_union_unpack_as_value!(as_char, Value::Char, char); -impl_union_unpack_as_value!(as_string, Value::String, HeapPos); -impl_union_unpack_as_value!(as_uint8, Value::UInt8, u8); -impl_union_unpack_as_value!(as_uint16, Value::UInt16, u16); -impl_union_unpack_as_value!(as_uint32, Value::UInt32, u32); -impl_union_unpack_as_value!(as_uint64, Value::UInt64, u64); -impl_union_unpack_as_value!(as_sint8, Value::SInt8, i8); -impl_union_unpack_as_value!(as_sint16, Value::SInt16, i16); -impl_union_unpack_as_value!(as_sint32, Value::SInt32, i32); -impl_union_unpack_as_value!(as_sint64, Value::SInt64, i64); -impl_union_unpack_as_value!(as_array, Value::Array, HeapPos); -impl_union_unpack_as_value!(as_tuple, Value::Tuple, HeapPos); -impl_union_unpack_as_value!(as_enum, Value::Enum, i64); -impl_union_unpack_as_value!(as_struct, Value::Struct, HeapPos); +union_cast_to_value_method_impl!(as_stack_boundary, isize, Value::PrevStackBoundary); +union_cast_to_value_method_impl!(as_ref, ValueId, Value::Ref); +union_cast_to_value_method_impl!(as_input, PortId, Value::Input); +union_cast_to_value_method_impl!(as_output, PortId, Value::Output); +union_cast_to_value_method_impl!(as_message, HeapPos, Value::Message); +union_cast_to_value_method_impl!(as_bool, bool, Value::Bool); +union_cast_to_value_method_impl!(as_char, char, Value::Char); +union_cast_to_value_method_impl!(as_string, HeapPos, Value::String); +union_cast_to_value_method_impl!(as_uint8, u8, Value::UInt8); +union_cast_to_value_method_impl!(as_uint16, u16, Value::UInt16); +union_cast_to_value_method_impl!(as_uint32, u32, Value::UInt32); +union_cast_to_value_method_impl!(as_uint64, u64, Value::UInt64); +union_cast_to_value_method_impl!(as_sint8, i8, Value::SInt8); +union_cast_to_value_method_impl!(as_sint16, i16, Value::SInt16); +union_cast_to_value_method_impl!(as_sint32, i32, Value::SInt32); +union_cast_to_value_method_impl!(as_sint64, i64, Value::SInt64); +union_cast_to_value_method_impl!(as_array, HeapPos, Value::Array); +union_cast_to_value_method_impl!(as_tuple, HeapPos, Value::Tuple); +union_cast_to_value_method_impl!(as_enum, i64, Value::Enum); +union_cast_to_value_method_impl!(as_struct, HeapPos, Value::Struct); impl Value { pub(crate) fn as_union(&self) -> (i64, HeapPos) { @@ -107,6 +94,14 @@ impl Value { } } + pub(crate) fn as_port_id(&self) -> PortId { + match self { + Value::Input(v) => *v, + Value::Output(v) => *v, + _ => unreachable!(), + } + } + pub(crate) fn is_integer(&self) -> bool { match self { Value::UInt8(_) | Value::UInt16(_) | Value::UInt32(_) | Value::UInt64(_) | diff --git a/src/protocol/mod.rs b/src/protocol/mod.rs index 28b49cea928f7667630eb86e32814cbb8792360b..8a479bec378d74d907e86d322134bde9c0e794a0 100644 --- a/src/protocol/mod.rs +++ b/src/protocol/mod.rs @@ -215,6 +215,9 @@ pub trait RunContext { fn fires(&mut self, port: PortId) -> Option; // None if not yet branched fn performed_fork(&mut self) -> Option; // None if not yet forked fn created_channel(&mut self) -> Option<(Value, Value)>; // None if not yet prepared + fn performed_select_start(&mut self) -> bool; // true if performed + fn performed_select_register_port(&mut self) -> bool; // true if registered + fn performed_select_wait(&mut self) -> Option; // None if not yet notified runtime of select blocker } pub struct ProtocolDescriptionBuilder { diff --git a/src/protocol/parser/pass_rewriting.rs b/src/protocol/parser/pass_rewriting.rs index d5080c1c04c63d750f292e144b621d08a2f66b4e..d4c57df2fb4049f5cad3aa38e938bcff0139d40d 100644 --- a/src/protocol/parser/pass_rewriting.rs +++ b/src/protocol/parser/pass_rewriting.rs @@ -291,6 +291,7 @@ impl Visitor for PassRewriting { // Final steps: set the statements of the replacement block statement, // and link all of those statements together + let first_stmt_id = transformed_stmts[0]; let mut last_stmt_id = transformed_stmts[0]; for stmt_id in transformed_stmts.iter().skip(1).copied() { set_ast_statement_next(ctx, last_stmt_id, stmt_id); @@ -298,6 +299,7 @@ impl Visitor for PassRewriting { } let outer_block_stmt = &mut ctx.heap[outer_block_id]; + outer_block_stmt.next = first_stmt_id; outer_block_stmt.statements = transformed_stmts; return Ok(()) diff --git a/src/protocol/parser/pass_typing.rs b/src/protocol/parser/pass_typing.rs index d1cf2c25d0c529597ddcb8945880bedb241529ee..984befe76a8648c5093a3bab15390769ffaadf36 100644 --- a/src/protocol/parser/pass_typing.rs +++ b/src/protocol/parser/pass_typing.rs @@ -872,22 +872,22 @@ enum InferenceRule { } impl InferenceRule { - union_cast_method_impl!(as_mono_template, InferenceRuleTemplate, InferenceRule::MonoTemplate); - union_cast_method_impl!(as_bi_equal, InferenceRuleBiEqual, InferenceRule::BiEqual); - union_cast_method_impl!(as_tri_equal_args, InferenceRuleTriEqualArgs, InferenceRule::TriEqualArgs); - union_cast_method_impl!(as_tri_equal_all, InferenceRuleTriEqualAll, InferenceRule::TriEqualAll); - union_cast_method_impl!(as_concatenate, InferenceRuleTwoArgs, InferenceRule::Concatenate); - union_cast_method_impl!(as_indexing_expr, InferenceRuleIndexingExpr, InferenceRule::IndexingExpr); - union_cast_method_impl!(as_slicing_expr, InferenceRuleSlicingExpr, InferenceRule::SlicingExpr); - union_cast_method_impl!(as_select_struct_field, InferenceRuleSelectStructField, InferenceRule::SelectStructField); - union_cast_method_impl!(as_select_tuple_member, InferenceRuleSelectTupleMember, InferenceRule::SelectTupleMember); - union_cast_method_impl!(as_literal_struct, InferenceRuleLiteralStruct, InferenceRule::LiteralStruct); - union_cast_method_impl!(as_literal_union, InferenceRuleLiteralUnion, InferenceRule::LiteralUnion); - union_cast_method_impl!(as_literal_array, InferenceRuleLiteralArray, InferenceRule::LiteralArray); - union_cast_method_impl!(as_literal_tuple, InferenceRuleLiteralTuple, InferenceRule::LiteralTuple); - union_cast_method_impl!(as_cast_expr, InferenceRuleCastExpr, InferenceRule::CastExpr); - union_cast_method_impl!(as_call_expr, InferenceRuleCallExpr, InferenceRule::CallExpr); - union_cast_method_impl!(as_variable_expr, InferenceRuleVariableExpr, InferenceRule::VariableExpr); + union_cast_to_ref_method_impl!(as_mono_template, InferenceRuleTemplate, InferenceRule::MonoTemplate); + union_cast_to_ref_method_impl!(as_bi_equal, InferenceRuleBiEqual, InferenceRule::BiEqual); + union_cast_to_ref_method_impl!(as_tri_equal_args, InferenceRuleTriEqualArgs, InferenceRule::TriEqualArgs); + union_cast_to_ref_method_impl!(as_tri_equal_all, InferenceRuleTriEqualAll, InferenceRule::TriEqualAll); + union_cast_to_ref_method_impl!(as_concatenate, InferenceRuleTwoArgs, InferenceRule::Concatenate); + union_cast_to_ref_method_impl!(as_indexing_expr, InferenceRuleIndexingExpr, InferenceRule::IndexingExpr); + union_cast_to_ref_method_impl!(as_slicing_expr, InferenceRuleSlicingExpr, InferenceRule::SlicingExpr); + union_cast_to_ref_method_impl!(as_select_struct_field, InferenceRuleSelectStructField, InferenceRule::SelectStructField); + union_cast_to_ref_method_impl!(as_select_tuple_member, InferenceRuleSelectTupleMember, InferenceRule::SelectTupleMember); + union_cast_to_ref_method_impl!(as_literal_struct, InferenceRuleLiteralStruct, InferenceRule::LiteralStruct); + union_cast_to_ref_method_impl!(as_literal_union, InferenceRuleLiteralUnion, InferenceRule::LiteralUnion); + union_cast_to_ref_method_impl!(as_literal_array, InferenceRuleLiteralArray, InferenceRule::LiteralArray); + union_cast_to_ref_method_impl!(as_literal_tuple, InferenceRuleLiteralTuple, InferenceRule::LiteralTuple); + union_cast_to_ref_method_impl!(as_cast_expr, InferenceRuleCastExpr, InferenceRule::CastExpr); + union_cast_to_ref_method_impl!(as_call_expr, InferenceRuleCallExpr, InferenceRule::CallExpr); + union_cast_to_ref_method_impl!(as_variable_expr, InferenceRuleVariableExpr, InferenceRule::VariableExpr); } // Note: InferenceRuleTemplate is `Copy`, so don't add dynamically allocated diff --git a/src/protocol/tests/utils.rs b/src/protocol/tests/utils.rs index 3767cdc6efdb6d417b055eb1724470c94929a0d4..74ae7df0032656594ca7bc8ced624f23bc6967f9 100644 --- a/src/protocol/tests/utils.rs +++ b/src/protocol/tests/utils.rs @@ -1267,23 +1267,12 @@ fn seek_expr_in_stmt bool>(heap: &Heap, start: StatementId struct FakeRunContext{} impl RunContext for FakeRunContext { - fn performed_put(&mut self, _port: PortId) -> bool { - unreachable!("'put' called in compiler testing code") - } - - fn performed_get(&mut self, _port: PortId) -> Option { - unreachable!("'get' called in compiler testing code") - } - - fn fires(&mut self, _port: PortId) -> Option { - unreachable!("'fires' called in compiler testing code") - } - - fn performed_fork(&mut self) -> Option { - unreachable!("'fork' called in compiler testing code") - } - - fn created_channel(&mut self) -> Option<(Value, Value)> { - unreachable!("channel created in compiler testing code") - } + fn performed_put(&mut self, _port: PortId) -> bool { unreachable!() } + fn performed_get(&mut self, _port: PortId) -> Option { unreachable!() } + fn fires(&mut self, _port: PortId) -> Option { unreachable!() } + fn performed_fork(&mut self) -> Option { unreachable!() } + fn created_channel(&mut self) -> Option<(Value, Value)> { unreachable!() } + fn performed_select_start(&mut self) -> bool { unreachable!() } + fn performed_select_register_port(&mut self) -> bool { unreachable!() } + fn performed_select_wait(&mut self) -> Option { unreachable!() } } \ No newline at end of file diff --git a/src/runtime/connector.rs b/src/runtime/connector.rs index f495ce8faa4812601639dba56217514833284aec..9809dd8ced586b0b5480b1f2278753cb527cc2ee 100644 --- a/src/runtime/connector.rs +++ b/src/runtime/connector.rs @@ -122,6 +122,10 @@ impl<'a> RunContext for ConnectorRunContext<'a>{ taken => unreachable!("prepared statement is '{:?}' during 'performed_fork()'", taken), }; } + + fn performed_select_start(&mut self) -> bool { unreachable!() } + fn performed_select_register_port(&mut self) -> bool { unreachable!() } + fn performed_select_wait(&mut self) -> Option { unreachable!() } } impl Connector for ConnectorPDL { diff --git a/src/runtime2/component/component_pdl.rs b/src/runtime2/component/component_pdl.rs index f53b091dacfa1cb043a387b893c130c36fd55614..011ce9823aeb18f74da57befad141de61da4ef0d 100644 --- a/src/runtime2/component/component_pdl.rs +++ b/src/runtime2/component/component_pdl.rs @@ -24,6 +24,9 @@ pub enum ExecStmt { CreatedChannel((Value, Value)), PerformedPut, PerformedGet(ValueGroup), + PerformedSelectStart, + PerformedSelectRegister, + PerformedSelectWait(u32), None, } @@ -78,6 +81,30 @@ impl RunContext for ExecCtx { _ => unreachable!(), } } + + fn performed_select_start(&mut self) -> bool { + match self.stmt.take() { + ExecStmt::None => return false, + ExecStmt::PerformedSelectStart => return true, + _ => unreachable!(), + } + } + + fn performed_select_register_port(&mut self) -> bool { + match self.stmt.take() { + ExecStmt::None => return false, + ExecStmt::PerformedSelectRegister => return true, + _ => unreachable!(), + } + } + + fn performed_select_wait(&mut self) -> Option { + match self.stmt.take() { + ExecStmt::None => return None, + ExecStmt::PerformedSelectWait(selected_case) => Some(selected_case), + _ => unreachable!(), + } + } } #[derive(Debug, Copy, Clone, PartialEq, Eq)] @@ -245,6 +272,19 @@ impl CompPDL { return Ok(CompScheduling::Immediate); } }, + EC::SelectStart(num_cases, num_ports) => { + debug_assert_eq!(self.mode, Mode::Sync); + todo!("finish handling select start") + }, + EC::SelectRegisterPort(case_index, port_index, port_id) => { + debug_assert_eq!(self.mode, Mode::Sync); + todo!("finish handling register port") + }, + EC::SelectWait => { + debug_assert_eq!(self.mode, Mode::Sync); + self.handle_select_wait(sched_ctx, comp_ctx); + todo!("finish handling select wait") + }, // Results that can be returned outside of sync mode EC::ComponentTerminated => { self.mode = Mode::StartExit; // next call we'll take care of the exit @@ -332,6 +372,13 @@ impl CompPDL { } } + /// Handles the moment where the PDL code has notified the runtime of all + /// the ports it is waiting on. + fn handle_select_wait(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx) { + sched_ctx.log("Component waiting for select conclusion"); + + } + fn handle_component_exit(&mut self, sched_ctx: &SchedulerCtx, comp_ctx: &mut CompCtx) { sched_ctx.log("Component exiting"); debug_assert_eq!(self.mode, Mode::StartExit); diff --git a/src/runtime2/component/consensus.rs b/src/runtime2/component/consensus.rs index 3ad9244b7ed94ce7b697a704fce159a7358b7d5d..b2bdfc0e49dedcc60515ed2e41aa32c3c5f874cb 100644 --- a/src/runtime2/component/consensus.rs +++ b/src/runtime2/component/consensus.rs @@ -30,6 +30,8 @@ enum Mode { NonSync, SyncBusy, SyncAwaitingSolution, + SelectBusy, + SelectWait, } struct SolutionCombiner { diff --git a/src/runtime2/tests/mod.rs b/src/runtime2/tests/mod.rs index 007407b42484808e587f36c85deed563a28efd0d..ed151bc3431831455d0c16588626da6b1c9778a9 100644 --- a/src/runtime2/tests/mod.rs +++ b/src/runtime2/tests/mod.rs @@ -82,4 +82,53 @@ fn test_component_communication() { }").expect("compilation"); let rt = Runtime::new(3, true, pd); create_component(&rt, "", "constructor", no_args()); +} + +#[test] +fn test_simple_select() { + let pd = ProtocolDescription::parse(b" + func infinite_assert(T val, T expected) -> () { + while (val != expected) { print(\"nope!\"); } + } + + primitive receiver(in in_a, in in_b, u32 num_sends) { + auto num_from_a = 0; + auto num_from_b = 0; + while (num_from_a + num_from_b < 2 * num_sends) { + sync select { + auto v = get(in_a) -> { + print(\"got something from A\"); + infinite_assert(v, num_from_a); + num_from_a += 1; + } + auto v = get(in_b) -> { + print(\"got something from B\"); + infinite_assert(v, num_from_b); + num_from_b +=1; + } + } + } + } + + primitive sender(out tx, u32 num_sends) { + auto index = 0; + while (index < num_sends) { + sync { + put(tx, index); + index += 1; + } + } + } + + composite constructor() { + auto num_sends = 3; + channel tx_a -> rx_a; + channel tx_b -> rx_b; + new sender(tx_a, num_sends); + new receiver(rx_a, rx_b, num_sends); + new sender(tx_b, num_sends); + } + ").expect("compilation"); + let rt = Runtime::new(1, true, pd); + create_component(&rt, "", "constructor", no_args()); } \ No newline at end of file