NOTE(review): this patch was recovered from a copy whose line breaks had been
collapsed. Some generic arguments appear to have been stripped from context
text (e.g. the `pub fn sync(&mut self) -> Result {` context line below and the
`impl From for ProtoComponentId` hunk heading). Restore that context line from
the real src/runtime/communication.rs before applying, or its hunk will not match.

diff --git a/src/common.rs b/src/common.rs
index f64a289d9c192138cca32e3d48dabe45fab72302..21f3d91b2c6977da112626c46870f0687e451248 100644
--- a/src/common.rs
+++ b/src/common.rs
@@ -112,6 +112,12 @@ impl From for ProtoComponentId {
         Self(id)
     }
 }
+impl From<&[u8]> for Payload {
+    /// Copies `s` into a newly allocated, reference-counted payload buffer.
+    fn from(s: &[u8]) -> Payload {
+        Payload(Arc::new(s.to_vec()))
+    }
+}
 impl Payload {
     pub fn new(len: usize) -> Payload {
         let mut v = Vec::with_capacity(len);
diff --git a/src/runtime/communication.rs b/src/runtime/communication.rs
index 0b3d617a931edbbd4ee660bfcace2562c46801bc..839cc539d37d3e5071e143f6af6f952206b034a6 100644
--- a/src/runtime/communication.rs
+++ b/src/runtime/communication.rs
@@ -85,6 +85,42 @@ impl SyncProtoContext<'_> {
 }

 impl Connector {
+    /// Records `payload` as the value to put on native port `port` in the
+    /// connector's most recent native batch.
+    ///
+    /// # Errors
+    ///
+    /// Checks are made in this order:
+    /// - [`PortOpError::PortUnavailable`] if `port` is not a native port of this connector;
+    /// - [`PortOpError::WrongPolarity`] if `port` is not a putter;
+    /// - [`PortOpError::NotConnected`] if the connector is still in its setup phase;
+    /// - [`PortOpError::MultipleOpsOnPort`] if `port` already has a payload in this batch.
+    ///
+    /// # Panics
+    ///
+    /// If `port` is native but has no recorded polarity, or if there is no native batch
+    /// while communicating; both are presumably invariants upheld elsewhere.
+    pub fn put(&mut self, port: PortId, payload: Payload) -> Result<(), PortOpError> {
+        use PortOpError::*;
+        if !self.native_ports.contains(&port) {
+            return Err(PortUnavailable);
+        }
+        let polarity = self.port_info.polarities.get(&port).expect("port has no polarity");
+        if Putter != *polarity {
+            return Err(WrongPolarity);
+        }
+        match &mut self.phased {
+            ConnectorPhased::Setup { .. } => Err(NotConnected),
+            ConnectorPhased::Communication { native_batches, .. } => {
+                let batch = native_batches.last_mut().expect("no native batch");
+                if batch.to_put.contains_key(&port) {
+                    return Err(MultipleOpsOnPort);
+                }
+                batch.to_put.insert(port, payload);
+                Ok(())
+            }
+        }
+    }
     pub fn sync(&mut self) -> Result {
         use SyncError::*;
         match &mut self.phased {
diff --git a/src/runtime/error.rs b/src/runtime/error.rs
index 672a3d6554cb00feecfe4579beed4fbe9111c55b..7de8323c2e0d468fbcf03a6b55ef1b8dd1d0b437 100644
--- a/src/runtime/error.rs
+++ b/src/runtime/error.rs
@@ -19,3 +19,27 @@ pub enum SyncError {
     InconsistentProtoComponent(ProtoComponentId),
     IndistinguishableBatches([usize; 2]),
 }
+/// Reasons a native port operation (such as `Connector::put`) is rejected.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum PortOpError {
+    /// The port's polarity does not permit the requested operation.
+    WrongPolarity,
+    /// The connector is still in its setup phase.
+    NotConnected,
+    /// The port already has an operation recorded in the current batch.
+    MultipleOpsOnPort,
+    /// The port is not one of the connector's native ports.
+    PortUnavailable,
+}
+impl std::fmt::Display for PortOpError {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let msg = match self {
+            Self::WrongPolarity => "port has the wrong polarity for this operation",
+            Self::NotConnected => "connector is not connected",
+            Self::MultipleOpsOnPort => "port already has an operation in this batch",
+            Self::PortUnavailable => "port is not a native port of this connector",
+        };
+        f.write_str(msg)
+    }
+}
+impl std::error::Error for PortOpError {}
diff --git a/src/runtime/my_tests.rs b/src/runtime/my_tests.rs
index d59d4007dfae04c3bb2d03a627ca054e9c4d495e..d702b34971338a724d540e61cbbae627f120e604 100644
--- a/src/runtime/my_tests.rs
+++ b/src/runtime/my_tests.rs
@@ -88,3 +88,30 @@ fn multithreaded_connect() {
     })
     .unwrap();
 }
+
+#[test]
+fn put_no_sync() {
+    let mut c = Connector::new_simple(MINIMAL_PROTO.clone(), 0);
+    let [o, _] = c.new_port_pair();
+    c.connect(Duration::from_secs(1)).unwrap();
+    c.put(o, (b"hi" as &[_]).into()).unwrap();
+}
+
+#[test]
+fn wrong_polarity_bad() {
+    let mut c = Connector::new_simple(MINIMAL_PROTO.clone(), 0);
+    let [_, i] = c.new_port_pair();
+    c.connect(Duration::from_secs(1)).unwrap();
+    // The second port of the pair is expected not to be a putter.
+    assert_eq!(c.put(i, (b"hi" as &[_]).into()), Err(PortOpError::WrongPolarity));
+}
+
+#[test]
+fn dup_put_bad() {
+    let mut c = Connector::new_simple(MINIMAL_PROTO.clone(), 0);
+    let [o, _] = c.new_port_pair();
+    c.connect(Duration::from_secs(1)).unwrap();
+    c.put(o, (b"hi" as &[_]).into()).unwrap();
+    // A second put on the same port within one batch must be rejected.
+    assert_eq!(c.put(o, (b"hi" as &[_]).into()), Err(PortOpError::MultipleOpsOnPort));
+}