diff --git a/src/runtime/communication.rs b/src/runtime/communication.rs index 6df77b1bb4258c18b5b0bde62ac79e5d09c250ff..442a458f12552ea7ff28ca8447569f367c4a66b4 100644 --- a/src/runtime/communication.rs +++ b/src/runtime/communication.rs @@ -4,18 +4,22 @@ use crate::runtime::{actors::*, endpoint::*, errors::*, *}; impl Controller { fn end_round_with_decision(&mut self, decision: Decision) -> Result<(), SyncErr> { log!(&mut self.inner.logger, "ENDING ROUND WITH DECISION! {:?}", &decision); - if let Decision::Success(predicate) = &decision { - // overwrite MonoN/P - self.inner.mono_n = - self.ephemeral.poly_n.take().unwrap().choose_mono(predicate).unwrap(); - self.inner.mono_ps.clear(); - self.inner.mono_ps.extend( - self.ephemeral - .poly_ps - .drain(..) - .map(|poly_p| poly_p.choose_mono(predicate).unwrap()), - ) - } + let ret = match &decision { + Decision::Success(predicate) => { + // overwrite MonoN/P + self.inner.mono_n = + self.ephemeral.poly_n.take().unwrap().choose_mono(predicate).unwrap(); + self.inner.mono_ps.clear(); + self.inner.mono_ps.extend( + self.ephemeral + .poly_ps + .drain(..) + .map(|poly_p| poly_p.choose_mono(predicate).unwrap()), + ); + Ok(()) + } + Decision::Failure => Err(SyncErr::Timeout), + }; let announcement = CommMsgContents::Announce { decision }.into_msg(self.inner.round_index); for &child_ekey in self.inner.family.children_ekeys.iter() { log!( @@ -33,7 +37,7 @@ impl Controller { } self.inner.round_index += 1; self.ephemeral.clear(); - Ok(()) + ret } // Drain self.ephemeral.solution_storage and handle the new locals. Return decision if one is found @@ -129,7 +133,7 @@ impl Controller { // If a native requires setting up, arg `sync_batches` is Some, and those are used as the sync batches. pub fn sync_round( &mut self, - deadline: Instant, + mut deadline: Option, sync_batches: Option>, ) -> Result<(), SyncErr> { log!( @@ -264,15 +268,43 @@ impl Controller { log!(&mut self.inner.logger, "`No decision yet`. Time to recv messages"); self.undelay_all(); 'recv_loop: loop { - log!(&mut self.inner.logger, "`POLLING`..."); - let received = self.recv(deadline)?.ok_or_else(|| { - log!(&mut self.inner.logger, ":( timing out"); - SyncErr::Timeout - })?; + log!(&mut self.inner.logger, "`POLLING` with deadline {:?}...", deadline); + let received = match deadline { + Some(d) => match self.recv(d)? { + Some(received) => received, + None => { + deadline = None; + match self.inner.family.parent_ekey { + Some(parent_ekey) => { + let announcement = Msg::CommMsg(CommMsg { + round_index: self.inner.round_index, + contents: CommMsgContents::Failure, + }); + log!( + &mut self.inner.logger, + "Forwarding {:?} to parent with ekey {:?}", + &announcement, + parent_ekey + ); + self.inner + .endpoint_exts + .get_mut(parent_ekey) + .expect("ss") + .endpoint + .send(announcement.clone())?; + } + None => return self.end_round_with_decision(Decision::Failure), + } + continue; + } + }, + None => self.recv(Instant::now() + Duration::from_secs(2))?.expect("DRIED UP"), + }; log!(&mut self.inner.logger, "::: message {:?}...", &received); let current_content = match received.msg { - Msg::SetupMsg(_) => { + Msg::SetupMsg(s) => { // This occurs in the event the connector was malformed during connect() + println!("WASNT EXPECTING {:?}", s); return Err(SyncErr::UnexpectedSetupMsg); } Msg::CommMsg(CommMsg { round_index, .. }) @@ -451,25 +483,6 @@ impl Controller { } } } - // 'timeout_loop: loop { - // log!(&mut self.inner.logger, "`POLLING (already timed out)`..."); - // let received = self.recv_blocking()?; - // log!(&mut self.inner.logger, "::: message {:?}...", &received); - // let current_content = match received.msg { - // Msg::SetupMsg(_) => { - // // This occurs in the event the connector was malformed during connect() - // return Err(SyncErr::UnexpectedSetupMsg); - // } - // Msg::CommMsg(CommMsg { round_index, contents }) => { - // if round_index > self.inner.round_index { - // self.delay(received); - // continue 'timeout_loop; - // } else { - // contents - // } - // } - // }; - // } } } impl ControllerEphemeral { diff --git a/src/runtime/connector.rs b/src/runtime/connector.rs index e56113e1298926048fe66fd8158b186478b905e0..6b27ba61f760e759a4f6c41bd1308974b2031342 100644 --- a/src/runtime/connector.rs +++ b/src/runtime/connector.rs @@ -115,7 +115,7 @@ impl Connector { if native_polarity != Putter { return Err(WrongPolarity); } - let sync_batch = connected.sync_batches.iter_mut().last().unwrap(); + let sync_batch = connected.sync_batches.iter_mut().last().expect("no sync batch!"); if sync_batch.puts.contains_key(&ekey) { return Err(DuplicateOperation); } @@ -134,7 +134,7 @@ impl Connector { if native_polarity != Getter { return Err(WrongPolarity); } - let sync_batch = connected.sync_batches.iter_mut().last().unwrap(); + let sync_batch = connected.sync_batches.iter_mut().last().expect("no sync batch!"); if sync_batch.gets.contains(&ekey) { return Err(DuplicateOperation); } @@ -159,9 +159,11 @@ impl Connector { }; // do the synchronous round! - connected.controller.sync_round(deadline, Some(connected.sync_batches.drain(..)))?; + let res = + connected.controller.sync_round(Some(deadline), Some(connected.sync_batches.drain(..))); connected.sync_batches.push(SyncBatch::default()); - Ok(connected.controller.inner.mono_n.result.as_mut().unwrap().0) + res?; + Ok(connected.controller.inner.mono_n.result.as_mut().expect("qqqs").0) } pub fn read_gotten(&self, native_port_index: usize) -> Result<&[u8], ReadGottenErr> { diff --git a/src/runtime/serde.rs b/src/runtime/serde.rs index 589c5f0f5c639310fd0b087e78403e1f50249a54..7be378e7a82e42ccbfd9187a07e5485cf33a735c 100644 --- a/src/runtime/serde.rs +++ b/src/runtime/serde.rs @@ -37,6 +37,8 @@ impl Read for MonitoredReader { ///////////////////////////////////////// +struct VarLenInt(u64); + macro_rules! ser_seq { ( $w:expr ) => {{ io::Result::Ok(()) @@ -51,6 +53,25 @@ macro_rules! ser_seq { } ///////////////////////////////////////// +impl Ser for W { + fn ser(&mut self, t: &bool) -> Result<(), std::io::Error> { + self.ser(&match t { + true => b'T', + false => b'F', + }) + } +} +impl De for R { + fn de(&mut self) -> Result { + let b: u8 = self.de()?; + Ok(match b { + b'T' => true, + b'F' => false, + _ => return Err(InvalidData.into()), + }) + } +} + impl Ser for W { fn ser(&mut self, t: &u8) -> Result<(), std::io::Error> { self.write_u8(*t) @@ -97,7 +118,7 @@ impl De for R { impl Ser for W { fn ser(&mut self, t: &Payload) -> Result<(), std::io::Error> { - self.ser(&ZigZag(t.len() as u64))?; + self.ser(&VarLenInt(t.len() as u64))?; for byte in t { self.ser(byte)?; } @@ -106,7 +127,7 @@ impl Ser for W { } impl De for R { fn de(&mut self) -> Result { - let ZigZag(len) = self.de()?; + let VarLenInt(len) = self.de()?; let mut x = Vec::with_capacity(len as usize); for _ in 0..len { x.push(self.de()?); @@ -115,55 +136,35 @@ impl De for R { } } -struct ZigZag(u64); -impl Ser for W { - fn ser(&mut self, t: &ZigZag) -> Result<(), std::io::Error> { +impl Ser for W { + fn ser(&mut self, t: &VarLenInt) -> Result<(), std::io::Error> { integer_encoding::VarIntWriter::write_varint(self, t.0).map(|_| ()) } } -impl De for R { - fn de(&mut self) -> Result { - integer_encoding::VarIntReader::read_varint(self).map(ZigZag) +impl De for R { + fn de(&mut self) -> Result { + integer_encoding::VarIntReader::read_varint(self).map(VarLenInt) } } impl Ser for W { fn ser(&mut self, t: &ChannelId) -> Result<(), std::io::Error> { self.ser(&t.controller_id)?; - self.ser(&ZigZag(t.channel_index as u64)) + self.ser(&VarLenInt(t.channel_index as u64)) } } impl De for R { fn de(&mut self) -> Result { Ok(ChannelId { controller_id: self.de()?, - channel_index: De::::de(self)?.0 as ChannelIndex, - }) - } -} - -impl Ser for W { - fn ser(&mut self, t: &bool) -> Result<(), std::io::Error> { - self.ser(&match t { - true => b'T', - false => b'F', - }) - } -} -impl De for R { - fn de(&mut self) -> Result { - let b: u8 = self.de()?; - Ok(match b { - b'T' => true, - b'F' => false, - _ => return Err(InvalidData.into()), + channel_index: De::::de(self)?.0 as ChannelIndex, }) } } impl Ser for W { fn ser(&mut self, t: &Predicate) -> Result<(), std::io::Error> { - self.ser(&ZigZag(t.assigned.len() as u64))?; + self.ser(&VarLenInt(t.assigned.len() as u64))?; for (channel_id, boolean) in &t.assigned { ser_seq![self, channel_id, boolean]?; } @@ -172,7 +173,7 @@ impl Ser for W { } impl De for R { fn de(&mut self) -> Result { - let ZigZag(len) = self.de()?; + let VarLenInt(len) = self.de()?; let mut assigned = BTreeMap::::default(); for _ in 0..len { assigned.insert(self.de()?, self.de()?); @@ -238,20 +239,22 @@ impl Ser for W { use {CommMsgContents::*, SetupMsg::*}; match t { Msg::SetupMsg(s) => match s { + // [flag, data] ChannelSetup { info } => ser_seq![self, &0u8, info], LeaderEcho { maybe_leader } => ser_seq![self, &1u8, maybe_leader], LeaderAnnounce { leader } => ser_seq![self, &2u8, leader], YouAreMyParent => ser_seq![self, &3u8], }, Msg::CommMsg(CommMsg { round_index, contents }) => { - let zig = &ZigZag(*round_index as u64); + // [flag, round_num, data] + let varlenint = &VarLenInt(*round_index as u64); match contents { SendPayload { payload_predicate, payload } => { - ser_seq![self, &4u8, zig, payload_predicate, payload] + ser_seq![self, &4u8, varlenint, payload_predicate, payload] } - Elaborate { partial_oracle } => ser_seq![self, &5u8, zig, partial_oracle], - Announce { decision } => ser_seq![self, &6u8, zig, decision], - Failure => ser_seq![self, &7u8], + Elaborate { partial_oracle } => ser_seq![self, &5u8, varlenint, partial_oracle], + Announce { decision } => ser_seq![self, &6u8, varlenint, decision], + Failure => ser_seq![self, &7u8, varlenint], } } } @@ -263,23 +266,26 @@ impl De for R { let b: u8 = self.de()?; Ok(match b { 0..=3 => Msg::SetupMsg(match b { - 0 => ChannelSetup { info: self.de()? }, - 1 => LeaderEcho { maybe_leader: self.de()? }, - 2 => LeaderAnnounce { leader: self.de()? }, - 3 => YouAreMyParent, + // [flag, data] + 0u8 => ChannelSetup { info: self.de()? }, + 1u8 => LeaderEcho { maybe_leader: self.de()? }, + 2u8 => LeaderAnnounce { leader: self.de()? }, + 3u8 => YouAreMyParent, _ => unreachable!(), }), - _ => { - let ZigZag(zig) = self.de()?; + 4..=7 => { + // [flag, round_num, data] + let VarLenInt(varlenint) = self.de()?; let contents = match b { - 4 => SendPayload { payload_predicate: self.de()?, payload: self.de()? }, - 5 => Elaborate { partial_oracle: self.de()? }, - 6 => Announce { decision: self.de()? }, - 7 => Failure, - _ => return Err(InvalidData.into()), + 4u8 => SendPayload { payload_predicate: self.de()?, payload: self.de()? }, + 5u8 => Elaborate { partial_oracle: self.de()? }, + 6u8 => Announce { decision: self.de()? }, + 7u8 => Failure, + _ => unreachable!(), }; - Msg::CommMsg(CommMsg { round_index: zig as usize, contents }) + Msg::CommMsg(CommMsg { round_index: varlenint as usize, contents }) } + _ => return Err(InvalidData.into()), }) } } diff --git a/src/test/connector.rs b/src/test/connector.rs index ac764a75602b30085b938d09b5eaa418579814f0..1fe8ef31d27b5795c30326b36233e82636d9079e 100644 --- a/src/test/connector.rs +++ b/src/test/connector.rs @@ -775,3 +775,58 @@ fn connector_causal_loop2() { }, ])); } + +#[test] +fn connector_recover() { + let connect_timeout = Duration::from_millis(1500); + let comm_timeout = Duration::from_millis(300); + let addrs = [next_addr()]; + fn putter_does(i: usize) -> bool { + i % 3 == 0 + } + fn getter_does(i: usize) -> bool { + i % 2 == 0 + } + fn expect_res(i: usize) -> Result { + if putter_does(i) && getter_does(i) { + Ok(0) + } else { + Err(SyncErr::Timeout) + } + } + const N: usize = 11; + assert!(run_connector_set(&[ + // + &|x| { + // Alice + x.configure(PDL, b"forward").unwrap(); + x.bind_port(0, Native).unwrap(); + x.bind_port(1, Passive(addrs[0])).unwrap(); + x.connect(connect_timeout).unwrap(); + + for i in 0..N { + if putter_does(i) { + assert_eq!(Ok(()), x.put(0, b"msg".to_vec())); + } + assert_eq!(expect_res(i), x.sync(comm_timeout)); + } + }, + &|x| { + // Bob + x.configure(PDL, b"forward").unwrap(); + x.bind_port(0, Active(addrs[0])).unwrap(); + x.bind_port(1, Native).unwrap(); + x.connect(connect_timeout).unwrap(); + + for i in 0..N { + if getter_does(i) { + assert_eq!(Ok(()), x.get(0)); + } + assert_eq!(expect_res(i), x.sync(comm_timeout)); + if expect_res(i).is_ok() { + assert_eq!(Ok(b"msg" as &[u8]), x.read_gotten(0)); + } + } + }, + ])); +}