diff --git a/src/runtime/serde.rs b/src/runtime/serde.rs new file mode 100644 index 0000000000000000000000000000000000000000..d44f535bed337d08a847d52a5bcaefd104872ffc --- /dev/null +++ b/src/runtime/serde.rs @@ -0,0 +1,308 @@ +use crate::common::*; +use crate::runtime::{ + endpoint::{CommMsg, CommMsgContents, EndpointInfo, Msg, SetupMsg}, + Predicate, +}; +use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt}; +use std::io::{ErrorKind::InvalidData, Read, Write}; + +pub trait Ser: Write { + fn ser(&mut self, t: &T) -> Result<(), std::io::Error>; +} +pub trait De: Read { + fn de(&mut self) -> Result; +} + +pub struct MonitoredReader { + bytes: usize, + r: R, +} +impl From for MonitoredReader { + fn from(r: R) -> Self { + Self { r, bytes: 0 } + } +} +impl MonitoredReader { + pub fn bytes_read(&self) -> usize { + self.bytes + } +} +impl Read for MonitoredReader { + fn read(&mut self, buf: &mut [u8]) -> Result { + let n = self.r.read(buf)?; + self.bytes += n; + Ok(n) + } +} + +///////////////////////////////////////// + +macro_rules! ser_seq { + ( $w:expr ) => {{ + io::Result::Ok(()) + }}; + ( $w:expr, $first:expr ) => {{ + $w.ser($first) + }}; + ( $w:expr, $first:expr, $( $x:expr ),+ ) => {{ + $w.ser($first)?; + ser_seq![$w, $( $x ),*] + }}; +} +///////////////////////////////////////// + +impl Ser for W { + fn ser(&mut self, t: &u8) -> Result<(), std::io::Error> { + self.write_u8(*t) + } +} +impl De for R { + fn de(&mut self) -> Result { + self.read_u8() + } +} + +impl Ser for W { + fn ser(&mut self, t: &u16) -> Result<(), std::io::Error> { + self.write_u16::(*t) + } +} +impl De for R { + fn de(&mut self) -> Result { + self.read_u16::() + } +} + +impl Ser for W { + fn ser(&mut self, t: &u32) -> Result<(), std::io::Error> { + self.write_u32::(*t) + } +} +impl De for R { + fn de(&mut self) -> Result { + self.read_u32::() + } +} + +impl Ser for W { + fn ser(&mut self, t: &u64) -> Result<(), std::io::Error> { + self.write_u64::(*t) + } +} +impl De for R { + fn de(&mut self) -> Result { + self.read_u64::() + } +} + +impl Ser for W { + fn ser(&mut self, t: &Payload) -> Result<(), std::io::Error> { + self.ser(&ZigZag(t.len() as u64))?; + for byte in t { + self.ser(byte)?; + } + Ok(()) + } +} +impl De for R { + fn de(&mut self) -> Result { + let ZigZag(len) = self.de()?; + let mut x = Vec::with_capacity(len as usize); + for _ in 0..len { + x.push(self.de()?); + } + Ok(x) + } +} + +struct ZigZag(u64); +impl Ser for W { + fn ser(&mut self, t: &ZigZag) -> Result<(), std::io::Error> { + integer_encoding::VarIntWriter::write_varint(self, t.0).map(|_| ()) + } +} +impl De for R { + fn de(&mut self) -> Result { + integer_encoding::VarIntReader::read_varint(self).map(ZigZag) + } +} + +impl Ser for W { + fn ser(&mut self, t: &ChannelId) -> Result<(), std::io::Error> { + self.ser(&t.controller_id)?; + self.ser(&ZigZag(t.channel_index as u64)) + } +} +impl De for R { + fn de(&mut self) -> Result { + Ok(ChannelId { + controller_id: self.de()?, + channel_index: De::::de(self)?.0 as ChannelIndex, + }) + } +} + +impl Ser for W { + fn ser(&mut self, t: &bool) -> Result<(), std::io::Error> { + self.ser(&match t { + true => b'T', + false => b'F', + }) + } +} +impl De for R { + fn de(&mut self) -> Result { + let b: u8 = self.de()?; + Ok(match b { + b'T' => true, + b'F' => false, + _ => return Err(InvalidData.into()), + }) + } +} + +impl Ser for W { + fn ser(&mut self, t: &Predicate) -> Result<(), std::io::Error> { + self.ser(&ZigZag(t.assigned.len() as u64))?; + for (channel_id, boolean) in &t.assigned { + ser_seq![self, channel_id, boolean]?; + } + Ok(()) + } +} +impl De for R { + fn de(&mut self) -> Result { + let ZigZag(len) = self.de()?; + let mut assigned = BTreeMap::::default(); + for _ in 0..len { + assigned.insert(self.de()?, self.de()?); + } + Ok(Predicate { assigned }) + } +} + +impl Ser for W { + fn ser(&mut self, t: &Polarity) -> Result<(), std::io::Error> { + self.ser(&match t { + Polarity::Putter => b'P', + Polarity::Getter => b'G', + }) + } +} +impl De for R { + fn de(&mut self) -> Result { + let b: u8 = self.de()?; + Ok(match b { + b'P' => Polarity::Putter, + b'G' => Polarity::Getter, + _ => return Err(InvalidData.into()), + }) + } +} + +impl Ser for W { + fn ser(&mut self, t: &EndpointInfo) -> Result<(), std::io::Error> { + let EndpointInfo { channel_id, polarity } = t; + ser_seq![self, channel_id, polarity] + } +} +impl De for R { + fn de(&mut self) -> Result { + Ok(EndpointInfo { channel_id: self.de()?, polarity: self.de()? }) + } +} + +impl Ser for W { + fn ser(&mut self, t: &Msg) -> Result<(), std::io::Error> { + use {CommMsgContents::*, SetupMsg::*}; + match t { + Msg::SetupMsg(s) => match s { + ChannelSetup { info } => ser_seq![self, &0u8, info], + LeaderEcho { maybe_leader } => ser_seq![self, &1u8, maybe_leader], + LeaderAnnounce { leader } => ser_seq![self, &2u8, leader], + YouAreMyParent => ser_seq![self, &3u8], + }, + Msg::CommMsg(CommMsg { round_index, contents }) => { + let zig = &ZigZag(*round_index as u64); + match contents { + SendPayload { payload_predicate, payload } => { + ser_seq![self, &4u8, zig, payload_predicate, payload] + } + Elaborate { partial_oracle } => ser_seq![self, &5u8, zig, partial_oracle], + Announce { oracle } => ser_seq![self, &6u8, zig, oracle], + } + } + } + } +} +impl De for R { + fn de(&mut self) -> Result { + use {CommMsgContents::*, SetupMsg::*}; + let b: u8 = self.de()?; + Ok(match b { + 0..=3 => Msg::SetupMsg(match b { + 0 => ChannelSetup { info: self.de()? }, + 1 => LeaderEcho { maybe_leader: self.de()? }, + 2 => LeaderAnnounce { leader: self.de()? }, + 3 => YouAreMyParent, + _ => unreachable!(), + }), + _ => { + let ZigZag(zig) = self.de()?; + let contents = match b { + 4 => SendPayload { payload_predicate: self.de()?, payload: self.de()? }, + 5 => Elaborate { partial_oracle: self.de()? }, + 6 => Announce { oracle: self.de()? }, + _ => return Err(InvalidData.into()), + }; + Msg::CommMsg(CommMsg { round_index: zig as usize, contents }) + } + }) + } +} + +///////////////// + +// #[test] +// fn my_serde() -> Result<(), std::io::Error> { +// let payload_predicate = Predicate { +// assigned: maplit::btreemap! { ChannelId {controller_id: 3, channel_index: 9} => false }, +// }; +// let msg = Msg::CommMsg(CommMsg { +// round_index: !0, +// contents: CommMsgContents::SendPayload { +// payload_predicate, +// payload: (0..).take(2).collect(), +// }, +// }); +// let mut v = vec![]; +// v.ser(&msg)?; +// print!("["); +// for (i, &x) in v.iter().enumerate() { +// print!("{:02x}", x); +// if i % 4 == 3 { +// print!(" "); +// } +// } +// println!("]"); + +// let msg2: Msg = (&v[..]).de()?; +// println!("msg2 {:#?}", msg2); +// Ok(()) +// } + +// #[test] +// fn varint() { +// let mut v = vec![]; +// v.ser(&ZigZag(!0)).unwrap(); +// for (i, x) in v.iter_mut().enumerate() { +// print!("{:02x}", x); +// if i % 4 == 3 { +// print!(" "); +// } +// } +// *v.iter_mut().last().unwrap() |= 3; + +// let ZigZag(x) = De::de(&mut &v[..]).unwrap(); +// println!(""); +// }