use super::vec_storage::VecStorage; use crate::common::*; use crate::runtime::endpoint::EndpointExt; use crate::runtime::endpoint::EndpointInfo; use crate::runtime::endpoint::{Endpoint, Msg, SetupMsg}; use crate::runtime::errors::EndpointErr; use crate::runtime::errors::MessengerRecvErr; use crate::runtime::errors::PollDeadlineErr; use crate::runtime::MessengerState; use crate::runtime::Messengerlike; use crate::runtime::ReceivedMsg; use crate::runtime::{ProtocolD, ProtocolS}; use std::net::SocketAddr; use std::sync::Arc; pub enum Coupling { Active, Passive, } #[derive(Debug)] struct Family { parent: Option, children: HashSet, } pub struct Binding { pub coupling: Coupling, pub polarity: Polarity, pub addr: SocketAddr, } pub struct InPort(Port); // InPort and OutPort are AFFINE (exposed to Rust API) pub struct OutPort(Port); impl From for Port { fn from(x: InPort) -> Self { x.0 } } impl From for Port { fn from(x: OutPort) -> Self { x.0 } } #[derive(Default, Debug)] struct ChannelIndexStream { next: u32, } impl ChannelIndexStream { fn next(&mut self) -> u32 { self.next += 1; self.next - 1 } } enum Connector { Connecting(Connecting), Connected(Connected), } #[derive(Default)] pub struct Connecting { bindings: Vec, } trait Binds { fn bind(&mut self, coupling: Coupling, addr: SocketAddr) -> T; } impl Binds for Connecting { fn bind(&mut self, coupling: Coupling, addr: SocketAddr) -> InPort { self.bindings.push(Binding { coupling, polarity: Polarity::Getter, addr }); InPort(Port(self.bindings.len() - 1)) } } impl Binds for Connecting { fn bind(&mut self, coupling: Coupling, addr: SocketAddr) -> OutPort { self.bindings.push(Binding { coupling, polarity: Polarity::Putter, addr }); OutPort(Port(self.bindings.len() - 1)) } } #[derive(Debug, Clone)] pub enum ConnectErr { BindErr(SocketAddr), NewSocketErr(SocketAddr), AcceptErr(SocketAddr), ConnectionShutdown(SocketAddr), PortKindMismatch(Port, SocketAddr), EndpointErr(Port, EndpointErr), PollInitFailed, PollingFailed, Timeout, } #[derive(Debug)] struct Component { protocol: Arc, port_set: HashSet, identifier: Arc<[u8]>, state: ProtocolS, } impl From for ConnectErr { fn from(e: PollDeadlineErr) -> Self { use PollDeadlineErr as P; match e { P::PollingFailed => Self::PollingFailed, P::Timeout => Self::Timeout, } } } impl From for ConnectErr { fn from(e: MessengerRecvErr) -> Self { use MessengerRecvErr as M; match e { M::PollingFailed => Self::PollingFailed, M::EndpointErr(port, err) => Self::EndpointErr(port, err), } } } impl Connecting { fn random_controller_id() -> ControllerId { type Bytes8 = [u8; std::mem::size_of::()]; let mut bytes = Bytes8::default(); getrandom::getrandom(&mut bytes).unwrap(); unsafe { // safe: // 1. All random bytes give valid Bytes8 // 2. Bytes8 and ControllerId have same valid representations std::mem::transmute::(bytes) } } fn test_stream_connectivity(stream: &mut TcpStream) -> bool { use std::io::Write; stream.write(&[]).is_ok() } fn new_connected( &self, controller_id: ControllerId, timeout: Option, ) -> Result { use ConnectErr::*; /////////////////////////////////////////////////////// // 1. bindings correspond with ports 0..bindings.len(). For each: // - reserve a slot in endpoint_exts. // - store the port in `native_ports' set. let mut endpoint_exts = VecStorage::::with_reserved_range(self.bindings.len()); let native_ports = (0..self.bindings.len()).map(Port).collect(); // 2. create MessengerState structure for polling channels let edge = PollOpt::edge(); let [ready_r, ready_w] = [Ready::readable(), Ready::writable()]; let mut ms = MessengerState::with_event_capacity(self.bindings.len()).map_err(|_| PollInitFailed)?; // 3. create one TODO task per (port,binding) as a vector with indices in lockstep. // we will drain it gradually so we store elements of type Option where all are initially Some(_) enum Todo { PassiveAccepting { listener: TcpListener, channel_id: ChannelId }, ActiveConnecting { stream: TcpStream }, PassiveConnecting { stream: TcpStream, channel_id: ChannelId }, ActiveRecving { endpoint: Endpoint }, } let mut channel_index_stream = ChannelIndexStream::default(); let mut todos = self .bindings .iter() .enumerate() .map(|(index, binding)| { Ok(Some(match binding.coupling { Coupling::Passive => { let channel_index = channel_index_stream.next(); let channel_id = ChannelId { controller_id, channel_index }; let listener = TcpListener::bind(&binding.addr).map_err(|_| BindErr(binding.addr))?; ms.poll.register(&listener, Token(index), ready_r, edge).unwrap(); // registration unique Todo::PassiveAccepting { listener, channel_id } } Coupling::Active => { let stream = TcpStream::connect(&binding.addr) .map_err(|_| NewSocketErr(binding.addr))?; ms.poll.register(&stream, Token(index), ready_w, edge).unwrap(); // registration unique Todo::ActiveConnecting { stream } } })) }) .collect::>, ConnectErr>>()?; let mut num_todos_remaining = todos.len(); // 4. handle incoming events until all TODOs are completed OR we timeout let deadline = timeout.map(|t| Instant::now() + t); let mut polled_undrained_later = IndexSet::<_>::default(); let mut backoff_millis = 10; while num_todos_remaining > 0 { ms.poll_events_until(deadline)?; for event in ms.events.iter() { let token = event.token(); let index = token.0; let binding = &self.bindings[index]; match todos[index].take() { None => { polled_undrained_later.insert(index); } Some(Todo::PassiveAccepting { listener, channel_id }) => { let (stream, _peer_addr) = listener.accept().map_err(|_| AcceptErr(binding.addr))?; ms.poll.deregister(&listener).expect("wer"); ms.poll.register(&stream, token, ready_w, edge).expect("3y5"); todos[index] = Some(Todo::PassiveConnecting { stream, channel_id }); } Some(Todo::ActiveConnecting { mut stream }) => { let todo = if Self::test_stream_connectivity(&mut stream) { ms.poll.reregister(&stream, token, ready_r, edge).expect("52"); let endpoint = Endpoint::from_fresh_stream(stream); Todo::ActiveRecving { endpoint } } else { ms.poll.deregister(&stream).expect("wt"); std::thread::sleep(Duration::from_millis(backoff_millis)); backoff_millis = ((backoff_millis as f32) * 1.2) as u64 + 3; let stream = TcpStream::connect(&binding.addr).unwrap(); ms.poll.register(&stream, token, ready_w, edge).expect("PAC 3"); Todo::ActiveConnecting { stream } }; todos[index] = Some(todo); } Some(Todo::PassiveConnecting { mut stream, channel_id }) => { if !Self::test_stream_connectivity(&mut stream) { return Err(ConnectionShutdown(binding.addr)); } ms.poll.reregister(&stream, token, ready_r, edge).expect("55"); let polarity = binding.polarity; let info = EndpointInfo { polarity, channel_id }; let msg = Msg::SetupMsg(SetupMsg::ChannelSetup { info }); let mut endpoint = Endpoint::from_fresh_stream(stream); endpoint.send(msg).map_err(|e| EndpointErr(Port(index), e))?; let endpoint_ext = EndpointExt { endpoint, info }; endpoint_exts.occupy_reserved(index, endpoint_ext); num_todos_remaining -= 1; } Some(Todo::ActiveRecving { mut endpoint }) => { let ekey = Port(index); 'recv_loop: while let Some(msg) = endpoint.recv().map_err(|e| EndpointErr(ekey, e))? { if let Msg::SetupMsg(SetupMsg::ChannelSetup { info }) = msg { if info.polarity == binding.polarity { return Err(PortKindMismatch(ekey, binding.addr)); } let channel_id = info.channel_id; let info = EndpointInfo { polarity: binding.polarity, channel_id }; ms.polled_undrained.insert(ekey); let endpoint_ext = EndpointExt { endpoint, info }; endpoint_exts.occupy_reserved(index, endpoint_ext); num_todos_remaining -= 1; break 'recv_loop; } else { ms.delayed.push(ReceivedMsg { recipient: ekey, msg }); } } } } } } assert_eq!(None, endpoint_exts.iter_reserved().next()); drop(todos); /////////////////////////////////////////////////////// // 1. construct `family', i.e. perform the sink tree setup procedure use {Msg::SetupMsg as S, SetupMsg::*}; let mut messenger = (&mut ms, &mut endpoint_exts); impl Messengerlike for (&mut MessengerState, &mut VecStorage) { fn get_state_mut(&mut self) -> &mut MessengerState { self.0 } fn get_endpoint_mut(&mut self, ekey: Key) -> &mut Endpoint { &mut self .1 .get_occupied_mut(ekey.to_raw() as usize) .expect("OUT OF BOUNDS") .endpoint } } // 1. broadcast my ID as the first echo. await reply from all in net_keylist let neighbors = (0..self.bindings.len()).map(Port); let echo = S(LeaderEcho { maybe_leader: controller_id }); let mut awaiting = IndexSet::::with_capacity(neighbors.len()); for n in neighbors.clone() { messenger.send(n, echo.clone()).map_err(|e| EndpointErr(n, e))?; awaiting.insert(n); } // 2. Receive incoming replies. whenever a higher-id echo arrives, // adopt it as leader, sender as parent, and reset the await set. let mut parent: Option = None; let mut my_leader = controller_id; messenger.undelay_all(); 'echo_loop: while !awaiting.is_empty() || parent.is_some() { let ReceivedMsg { recipient, msg } = messenger.recv_until(deadline)?.ok_or(Timeout)?; match msg { S(LeaderAnnounce { leader }) => { // someone else completed the echo and became leader first! // the sender is my parent parent = Some(recipient); my_leader = leader; awaiting.clear(); break 'echo_loop; } S(LeaderEcho { maybe_leader }) => { use Ordering::*; match maybe_leader.cmp(&my_leader) { Less => { /* ignore */ } Equal => { awaiting.remove(&recipient); if awaiting.is_empty() { if let Some(p) = parent { // return the echo to my parent messenger .send(p, S(LeaderEcho { maybe_leader })) .map_err(|e| EndpointErr(p, e))?; } else { // DECIDE! break 'echo_loop; } } } Greater => { // join new echo parent = Some(recipient); my_leader = maybe_leader; let echo = S(LeaderEcho { maybe_leader: my_leader }); awaiting.clear(); if neighbors.len() == 1 { // immediately reply to parent messenger .send(recipient, echo.clone()) .map_err(|e| EndpointErr(recipient, e))?; } else { for n in neighbors.clone() { if n != recipient { messenger .send(n, echo.clone()) .map_err(|e| EndpointErr(n, e))?; awaiting.insert(n); } } } } } } msg => messenger.delay(ReceivedMsg { recipient, msg }), } } match parent { None => assert_eq!( my_leader, controller_id, "I've got no parent, but I consider {:?} the leader?", my_leader ), Some(parent) => assert_ne!( my_leader, controller_id, "I have {:?} as parent, but I consider myself ({:?}) the leader?", parent, controller_id ), } // 3. broadcast leader announcement (except to parent: confirm they are your parent) // in this loop, every node sends 1 message to each neighbor let msg_for_non_parents = S(LeaderAnnounce { leader: my_leader }); for n in neighbors.clone() { let msg = if Some(n) == parent { S(YouAreMyParent) } else { msg_for_non_parents.clone() }; messenger.send(n, msg).map_err(|e| EndpointErr(n, e))?; } // await 1 message from all non-parents for n in neighbors.clone() { if Some(n) != parent { awaiting.insert(n); } } let mut children = HashSet::default(); messenger.undelay_all(); while !awaiting.is_empty() { let ReceivedMsg { recipient, msg } = messenger.recv_until(deadline)?.ok_or(Timeout)?; let recipient = recipient; match msg { S(YouAreMyParent) => { assert!(awaiting.remove(&recipient)); children.insert(recipient); } S(SetupMsg::LeaderAnnounce { leader }) => { assert!(awaiting.remove(&recipient)); assert!(leader == my_leader); assert!(Some(recipient) != parent); // they wouldn't send me this if they considered me their parent } _ => messenger.delay(ReceivedMsg { recipient, msg }), } } let family = Family { parent, children }; // done! Ok(Connected { components: Default::default(), controller_id, channel_index_stream, endpoint_exts, native_ports, family, }) } ///////// pub fn connect_using_id( &mut self, controller_id: ControllerId, timeout: Option, ) -> Result { // 1. try and create a connection from these bindings with self immutable. let connected = self.new_connected(controller_id, timeout)?; // 2. success! drain self and return self.bindings.clear(); Ok(connected) } pub fn connect(&mut self, timeout: Option) -> Result { self.connect_using_id(Self::random_controller_id(), timeout) } } #[derive(Debug)] pub struct Connected { native_ports: HashSet, controller_id: ControllerId, channel_index_stream: ChannelIndexStream, endpoint_exts: VecStorage, components: VecStorage, family: Family, } impl Connected { pub fn new_component( &mut self, protocol: &Arc, identifier: &Arc<[u8]>, moved_port_list: &[Port], ) -> Result<(), MainComponentErr> { ////////////////////////////////////////// // 1. try and create a new component (without mutating self) use MainComponentErr::*; let moved_port_set = { let mut set: HashSet = Default::default(); for &port in moved_port_list.iter() { if !self.native_ports.contains(&port) { return Err(CannotMovePort(port)); } if !set.insert(port) { return Err(DuplicateMovedPort(port)); } } set }; // moved_port_set is disjoint to native_ports let expected_polarities = protocol.component_polarities(identifier)?; if moved_port_list.len() != expected_polarities.len() { return Err(WrongNumberOfParamaters { expected: expected_polarities.len() }); } // correct polarity list for (param_index, (&port, &expected_polarity)) in moved_port_list.iter().zip(expected_polarities.iter()).enumerate() { let polarity = self.endpoint_exts.get_occupied(port.0).ok_or(UnknownPort(port))?.info.polarity; if polarity != expected_polarity { return Err(WrongPortPolarity { param_index, port }); } } let state = protocol.new_main_component(identifier, &moved_port_list); let component = Component { port_set: moved_port_set, protocol: protocol.clone(), identifier: identifier.clone(), state, }; ////////////////////////////// // success! mutate self and return Ok self.native_ports.retain(|e| !component.port_set.contains(e)); self.components.new_occupied(component); Ok(()) } pub fn new_channel(&mut self) -> (OutPort, InPort) { assert!(self.endpoint_exts.len() <= std::u32::MAX as usize - 2); let channel_id = ChannelId { controller_id: self.controller_id, channel_index: self.channel_index_stream.next(), }; let [e0, e1] = Endpoint::new_memory_pair(); let kp = self.endpoint_exts.new_occupied(EndpointExt { info: EndpointInfo { channel_id, polarity: Putter }, endpoint: e0, }); let kg = self.endpoint_exts.new_occupied(EndpointExt { info: EndpointInfo { channel_id, polarity: Getter }, endpoint: e1, }); (OutPort(Port(kp)), InPort(Port(kg))) } pub fn sync_set(&mut self, _inbuf: &mut [u8], _ops: &mut [PortOpRs]) -> Result<(), ()> { Ok(()) } pub fn sync_subsets( &mut self, _inbuf: &mut [u8], _ops: &mut [PortOpRs], bit_subsets: &[&[usize]], ) -> Result { for (batch_index, bit_subset) in bit_subsets.iter().enumerate() { println!("batch_index {:?}", batch_index); let chunk_iter = bit_subset.iter().copied(); for index in BitChunkIter::new(chunk_iter) { println!(" index {:?}", index); } } Ok(0) } } macro_rules! bitslice { ($( $num:expr ),*) => {{ &[0 $( | (1usize << $num) )*] }}; } #[test] fn api_new_test() { let mut c = Connecting::default(); let net_out: OutPort = c.bind(Coupling::Active, "127.0.0.1:8000".parse().unwrap()); let net_in: InPort = c.bind(Coupling::Active, "127.0.0.1:8001".parse().unwrap()); let proto_0 = Arc::new(ProtocolD::parse(b"").unwrap()); let mut c = c.connect(None).unwrap(); let (mem_out, mem_in) = c.new_channel(); let mut inbuf = [0u8; 64]; let identifier: Arc<[u8]> = b"sync".to_vec().into(); c.new_component(&proto_0, &identifier, &[net_in.into(), mem_out.into()]).unwrap(); let mut ops = [ PortOpRs::In { msg_range: None, port: &mem_in }, PortOpRs::Out { msg: b"hey", port: &net_out, optional: false }, PortOpRs::Out { msg: b"hi?", port: &net_out, optional: true }, PortOpRs::Out { msg: b"yo!", port: &net_out, optional: false }, ]; c.sync_set(&mut inbuf, &mut ops).unwrap(); c.sync_subsets(&mut inbuf, &mut ops, &[bitslice! {0,1,2}]).unwrap(); } #[repr(C)] pub struct PortOp { msgptr: *mut u8, // read if OUT, field written if IN, will point into buf msglen: usize, // read if OUT, written if IN, won't exceed buf port: Port, optional: bool, // no meaning if } pub enum PortOpRs<'a> { In { msg_range: Option>, port: &'a InPort }, Out { msg: &'a [u8], port: &'a OutPort, optional: bool }, } unsafe fn c_sync_set( connected: &mut Connected, inbuflen: usize, inbufptr: *mut u8, opslen: usize, opsptr: *mut PortOp, ) -> i32 { let buf = as_mut_slice(inbuflen, inbufptr); let ops = as_mut_slice(opslen, opsptr); let (subset_index, wrote) = sync_inner(connected, buf); assert_eq!(0, subset_index); for op in ops { if let Some(range) = wrote.get(&op.port) { op.msgptr = inbufptr.add(range.start); op.msglen = range.end - range.start; } } 0 } use super::bits::{usizes_for_bits, BitChunkIter}; unsafe fn c_sync_subset( connected: &mut Connected, inbuflen: usize, inbufptr: *mut u8, opslen: usize, opsptr: *mut PortOp, subsetslen: usize, subsetsptr: *const *const usize, ) -> i32 { let buf: &mut [u8] = as_mut_slice(inbuflen, inbufptr); let ops: &mut [PortOp] = as_mut_slice(opslen, opsptr); let subsets: &[*const usize] = as_const_slice(subsetslen, subsetsptr); let subsetlen = usizes_for_bits(opslen); // don't yet know subsetptr; which subset fires unknown! let (subset_index, wrote) = sync_inner(connected, buf); let subsetptr: *const usize = subsets[subset_index]; let subset: &[usize] = as_const_slice(subsetlen, subsetptr); for index in BitChunkIter::new(subset.iter().copied()) { let op = &mut ops[index as usize]; if let Some(range) = wrote.get(&op.port) { op.msgptr = inbufptr.add(range.start); op.msglen = range.end - range.start; } } subset_index as i32 } // dummy fn for the actual synchronous round fn sync_inner<'c, 'b>( _connected: &'c mut Connected, _buf: &'b mut [u8], ) -> (usize, &'b HashMap>) { todo!() } unsafe fn as_mut_slice<'a, T>(len: usize, ptr: *mut T) -> &'a mut [T] { std::slice::from_raw_parts_mut(ptr, len) } unsafe fn as_const_slice<'a, T>(len: usize, ptr: *const T) -> &'a [T] { std::slice::from_raw_parts(ptr, len) } #[test] fn api_connecting() { let addrs: [SocketAddr; 3] = [ "127.0.0.1:8888".parse().unwrap(), "127.0.0.1:8889".parse().unwrap(), "127.0.0.1:8890".parse().unwrap(), ]; const TIMEOUT: Option = Some(Duration::from_secs(1)); let handles = vec![ std::thread::spawn(move || { let mut connecting = Connecting::default(); let _a: OutPort = connecting.bind(Coupling::Passive, addrs[0]); let _b: OutPort = connecting.bind(Coupling::Active, addrs[1]); let connected = connecting.connect(TIMEOUT); println!("A: {:#?}", connected); }), std::thread::spawn(move || { let mut connecting = Connecting::default(); let _a: InPort = connecting.bind(Coupling::Active, addrs[0]); let _b: InPort = connecting.bind(Coupling::Passive, addrs[1]); let _c: InPort = connecting.bind(Coupling::Active, addrs[2]); let connected = connecting.connect(TIMEOUT); println!("B: {:#?}", connected); }), std::thread::spawn(move || { let mut connecting = Connecting::default(); let _a: OutPort = connecting.bind(Coupling::Passive, addrs[2]); let connected = connecting.connect(TIMEOUT); println!("C: {:#?}", connected); }), ]; for h in handles { h.join().unwrap(); } }