diff --git a/src/ffi/socket_api.rs b/src/ffi/socket_api.rs index 8e701d1c72ff1c5cc800e2dbca35c14a3150757f..2d38ed9abaaec00b28dafddffa09428c918240b2 100644 --- a/src/ffi/socket_api.rs +++ b/src/ffi/socket_api.rs @@ -1,23 +1,41 @@ use super::*; use atomic_refcell::AtomicRefCell; -use std::{collections::HashMap, ffi::c_void, net::SocketAddr, os::raw::c_int, sync::RwLock}; +use std::{ + collections::HashMap, + ffi::c_void, + net::{Ipv4Addr, SocketAddr, SocketAddrV4}, + os::raw::c_int, + sync::RwLock, +}; /////////////////////////////////////////////////////////////////// struct FdAllocator { next: Option, freed: Vec, } -enum MaybeConnector { - New, - Bound { local_addr: SocketAddr }, - Connected { connector: Connector, putter: PortId, getter: PortId }, +struct ConnectorBound { + connector: Connector, + putter: PortId, + getter: PortId, +} +struct MaybeConnector { + // invariants: + // 1. connector is a upd-socket singleton + // 2. putter and getter are ports in the native interface with the appropriate polarities + // 3. peer_addr always mirrors connector's single udp socket's connect addr. both are overwritten together. + peer_addr: SocketAddr, + connector_bound: Option, } #[derive(Default)] -struct ConnectorStorage { - fd_to_connector: HashMap>, +struct CspStorage { + fd_to_mc: HashMap>, fd_allocator: FdAllocator, } +fn trivial_peer_addr() -> SocketAddr { + // SocketAddrV4::new isn't a constant-time func + SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::new(0, 0, 0, 0), 0)) +} /////////////////////////////////////////////////////////////////// impl Default for FdAllocator { @@ -44,25 +62,26 @@ impl FdAllocator { } } lazy_static::lazy_static! { - static ref CONNECTOR_STORAGE: RwLock = Default::default(); + static ref CSP_STORAGE: RwLock = Default::default(); } /////////////////////////////////////////////////////////////////// #[no_mangle] pub extern "C" fn rw_socket(_domain: c_int, _type: c_int) -> c_int { - // assuming _domain is AF_INET and _type is SOCK_DGRAM - let mut w = if let Ok(w) = CONNECTOR_STORAGE.write() { w } else { return FD_LOCK_POISONED }; + // ignoring domain and type + let mut w = if let Ok(w) = CSP_STORAGE.write() { w } else { return FD_LOCK_POISONED }; let fd = w.fd_allocator.alloc(); - w.fd_to_connector.insert(fd, AtomicRefCell::new(MaybeConnector::New)); + let mc = MaybeConnector { peer_addr: trivial_peer_addr(), connector_bound: None }; + w.fd_to_mc.insert(fd, AtomicRefCell::new(mc)); fd } #[no_mangle] pub extern "C" fn rw_close(fd: c_int, _how: c_int) -> c_int { // ignoring HOW - let mut w = if let Ok(w) = CONNECTOR_STORAGE.write() { w } else { return FD_LOCK_POISONED }; + let mut w = if let Ok(w) = CSP_STORAGE.write() { w } else { return FD_LOCK_POISONED }; w.fd_allocator.free(fd); - if w.fd_to_connector.remove(&fd).is_some() { + if w.fd_to_mc.remove(&fd).is_some() { ERR_OK } else { CLOSE_FAIL @@ -75,13 +94,22 @@ pub unsafe extern "C" fn rw_bind( local_addr: *const SocketAddr, _addr_len: usize, ) -> c_int { - use MaybeConnector as Mc; // assuming _domain is AF_INET and _type is SOCK_DGRAM - let r = if let Ok(r) = CONNECTOR_STORAGE.read() { r } else { return FD_LOCK_POISONED }; - let mc = if let Some(mc) = r.fd_to_connector.get(&fd) { mc } else { return BAD_FD }; - let mc: &mut Mc = &mut mc.borrow_mut(); - let _ = if let Mc::New = mc { () } else { return WRONG_STATE }; - *mc = Mc::Bound { local_addr: local_addr.read() }; + let r = if let Ok(r) = CSP_STORAGE.read() { r } else { return FD_LOCK_POISONED }; + let mc = if let Some(mc) = r.fd_to_mc.get(&fd) { mc } else { return BAD_FD }; + let mc: &mut MaybeConnector = &mut mc.borrow_mut(); + if mc.connector_bound.is_some() { + return WRONG_STATE; + } + mc.connector_bound = { + let mut connector = Connector::new( + Box::new(crate::DummyLogger), + crate::TRIVIAL_PD.clone(), + Connector::random_id(), + ); + let [putter, getter] = connector.new_udp_port(local_addr.read(), mc.peer_addr).unwrap(); + Some(ConnectorBound { connector, putter, getter }) + }; ERR_OK } @@ -91,27 +119,19 @@ pub unsafe extern "C" fn rw_connect( peer_addr: *const SocketAddr, _address_len: usize, ) -> c_int { - use MaybeConnector as Mc; // assuming _domain is AF_INET and _type is SOCK_DGRAM - let r = if let Ok(r) = CONNECTOR_STORAGE.read() { r } else { return FD_LOCK_POISONED }; - let mc = if let Some(mc) = r.fd_to_connector.get(&fd) { mc } else { return BAD_FD }; - let mc: &mut Mc = &mut mc.borrow_mut(); - let local_addr = - if let Mc::Bound { local_addr } = mc { local_addr } else { return WRONG_STATE }; - let peer_addr = peer_addr.read(); - let (connector, [putter, getter]) = { - let mut c = Connector::new( - Box::new(DummyLogger), - crate::TRIVIAL_PD.clone(), - Connector::random_id(), - 8, - ); - let [putter, getter] = c.new_udp_port(*local_addr, peer_addr).unwrap(); - (c, [putter, getter]) - }; - *mc = Mc::Connected { connector, putter, getter }; + let r = if let Ok(r) = CSP_STORAGE.read() { r } else { return FD_LOCK_POISONED }; + let mc = if let Some(mc) = r.fd_to_mc.get(&fd) { mc } else { return BAD_FD }; + let mc: &mut MaybeConnector = &mut mc.borrow_mut(); + mc.peer_addr = peer_addr.read(); + if let Some(ConnectorBound { connector, .. }) = &mut mc.connector_bound { + if connector.get_mut_udp_sock(0).unwrap().connect(mc.peer_addr).is_err() { + return CONNECT_FAILED; + } + } ERR_OK } + #[no_mangle] pub unsafe extern "C" fn rw_send( fd: c_int, @@ -119,22 +139,18 @@ pub unsafe extern "C" fn rw_send( bytes_len: usize, _flags: c_int, ) -> isize { - use MaybeConnector as Mc; // ignoring flags - let r = - if let Ok(r) = CONNECTOR_STORAGE.read() { r } else { return FD_LOCK_POISONED as isize }; - let mc = if let Some(mc) = r.fd_to_connector.get(&fd) { mc } else { return BAD_FD as isize }; - let mc: &mut Mc = &mut mc.borrow_mut(); - let (connector, putter) = if let Mc::Connected { connector, putter, .. } = mc { - (connector, *putter) + let r = if let Ok(r) = CSP_STORAGE.read() { r } else { return FD_LOCK_POISONED as isize }; + let mc = if let Some(mc) = r.fd_to_mc.get(&fd) { mc } else { return BAD_FD as isize }; + let mc: &mut MaybeConnector = &mut mc.borrow_mut(); + if let Some(ConnectorBound { connector, putter, .. }) = &mut mc.connector_bound { + match connector_put_bytes(connector, *putter, bytes_ptr as _, bytes_len) { + ERR_OK => connector_sync(connector, -1), + err => err as isize, + } } else { - return WRONG_STATE as isize; - }; - match connector_put_bytes(connector, putter, bytes_ptr as _, bytes_len) { - ERR_OK => {} - err => return err as isize, + WRONG_STATE as isize // not bound! } - connector_sync(connector, -1) } #[no_mangle] @@ -144,27 +160,23 @@ pub unsafe extern "C" fn rw_recv( bytes_len: usize, _flags: c_int, ) -> isize { - use MaybeConnector as Mc; // ignoring flags - let r = - if let Ok(r) = CONNECTOR_STORAGE.read() { r } else { return FD_LOCK_POISONED as isize }; - let mc = if let Some(mc) = r.fd_to_connector.get(&fd) { mc } else { return BAD_FD as isize }; - let mc: &mut Mc = &mut mc.borrow_mut(); - let (connector, getter) = if let Mc::Connected { connector, getter, .. } = mc { - (connector, *getter) + let r = if let Ok(r) = CSP_STORAGE.read() { r } else { return FD_LOCK_POISONED as isize }; + let mc = if let Some(mc) = r.fd_to_mc.get(&fd) { mc } else { return BAD_FD as isize }; + let mc: &mut MaybeConnector = &mut mc.borrow_mut(); + if let Some(ConnectorBound { connector, getter, .. }) = &mut mc.connector_bound { + connector_get(connector, *getter); + match connector_sync(connector, -1) { + 0 => { + // batch index 0 means OK + let slice = connector.gotten(*getter).unwrap().as_slice(); + let copied_bytes = slice.len().min(bytes_len); + std::ptr::copy_nonoverlapping(slice.as_ptr(), bytes_ptr as *mut u8, copied_bytes); + copied_bytes as isize + } + err => return err as isize, + } } else { - return WRONG_STATE as isize; - }; - match connector_get(connector, getter) { - ERR_OK => {} - err => return err as isize, + WRONG_STATE as isize // not bound! } - match connector_sync(connector, -1) { - 0 => {} // singleton batch index - err => return err as isize, - }; - let slice = connector.gotten(getter).unwrap().as_slice(); - let copied_bytes = slice.len().min(bytes_len); - std::ptr::copy_nonoverlapping(slice.as_ptr(), bytes_ptr as *mut u8, copied_bytes); - copied_bytes as isize }