diff --git a/src/ffi/socket_api.rs b/src/ffi/socket_api.rs index 2d38ed9abaaec00b28dafddffa09428c918240b2..b297ea144d5e0792c805aeabf998d55cb9d78bd9 100644 --- a/src/ffi/socket_api.rs +++ b/src/ffi/socket_api.rs @@ -1,5 +1,4 @@ use super::*; -use atomic_refcell::AtomicRefCell; use std::{ collections::HashMap, @@ -29,7 +28,7 @@ struct MaybeConnector { } #[derive(Default)] struct CspStorage { - fd_to_mc: HashMap>, + fd_to_mc: HashMap>, fd_allocator: FdAllocator, } fn trivial_peer_addr() -> SocketAddr { @@ -64,6 +63,48 @@ impl FdAllocator { lazy_static::lazy_static! { static ref CSP_STORAGE: RwLock = Default::default(); } +impl MaybeConnector { + fn connect(&mut self, peer_addr: SocketAddr) -> c_int { + self.peer_addr = peer_addr; + if let Some(ConnectorBound { connector, .. }) = &mut self.connector_bound { + if connector.get_mut_udp_sock(0).unwrap().connect(peer_addr).is_err() { + return CONNECT_FAILED; + } + } + ERR_OK + } + unsafe fn send(&mut self, bytes_ptr: *const c_void, bytes_len: usize) -> isize { + if let Some(ConnectorBound { connector, putter, .. }) = &mut self.connector_bound { + match connector_put_bytes(connector, *putter, bytes_ptr as _, bytes_len) { + ERR_OK => connector_sync(connector, -1), + err => err as isize, + } + } else { + WRONG_STATE as isize // not bound! + } + } + unsafe fn recv(&mut self, bytes_ptr: *const c_void, bytes_len: usize) -> isize { + if let Some(ConnectorBound { connector, getter, .. }) = &mut self.connector_bound { + connector_get(connector, *getter); + match connector_sync(connector, -1) { + 0 => { + // batch index 0 means OK + let slice = connector.gotten(*getter).unwrap().as_slice(); + let copied_bytes = slice.len().min(bytes_len); + std::ptr::copy_nonoverlapping( + slice.as_ptr(), + bytes_ptr as *mut u8, + copied_bytes, + ); + copied_bytes as isize + } + err => return err as isize, + } + } else { + WRONG_STATE as isize // not bound! + } + } +} /////////////////////////////////////////////////////////////////// #[no_mangle] @@ -72,7 +113,7 @@ pub extern "C" fn rw_socket(_domain: c_int, _type: c_int) -> c_int { let mut w = if let Ok(w) = CSP_STORAGE.write() { w } else { return FD_LOCK_POISONED }; let fd = w.fd_allocator.alloc(); let mc = MaybeConnector { peer_addr: trivial_peer_addr(), connector_bound: None }; - w.fd_to_mc.insert(fd, AtomicRefCell::new(mc)); + w.fd_to_mc.insert(fd, RwLock::new(mc)); fd } @@ -97,7 +138,8 @@ pub unsafe extern "C" fn rw_bind( // assuming _domain is AF_INET and _type is SOCK_DGRAM let r = if let Ok(r) = CSP_STORAGE.read() { r } else { return FD_LOCK_POISONED }; let mc = if let Some(mc) = r.fd_to_mc.get(&fd) { mc } else { return BAD_FD }; - let mc: &mut MaybeConnector = &mut mc.borrow_mut(); + let mut mc = if let Ok(mc) = mc.write() { mc } else { return FD_LOCK_POISONED }; + let mc: &mut MaybeConnector = &mut mc; if mc.connector_bound.is_some() { return WRONG_STATE; } @@ -122,14 +164,9 @@ pub unsafe extern "C" fn rw_connect( // assuming _domain is AF_INET and _type is SOCK_DGRAM let r = if let Ok(r) = CSP_STORAGE.read() { r } else { return FD_LOCK_POISONED }; let mc = if let Some(mc) = r.fd_to_mc.get(&fd) { mc } else { return BAD_FD }; - let mc: &mut MaybeConnector = &mut mc.borrow_mut(); - mc.peer_addr = peer_addr.read(); - if let Some(ConnectorBound { connector, .. }) = &mut mc.connector_bound { - if connector.get_mut_udp_sock(0).unwrap().connect(mc.peer_addr).is_err() { - return CONNECT_FAILED; - } - } - ERR_OK + let mut mc = if let Ok(mc) = mc.write() { mc } else { return FD_LOCK_POISONED }; + let mc: &mut MaybeConnector = &mut mc; + mc.connect(peer_addr.read()) } #[no_mangle] @@ -142,15 +179,9 @@ pub unsafe extern "C" fn rw_send( // ignoring flags let r = if let Ok(r) = CSP_STORAGE.read() { r } else { return FD_LOCK_POISONED as isize }; let mc = if let Some(mc) = r.fd_to_mc.get(&fd) { mc } else { return BAD_FD as isize }; - let mc: &mut MaybeConnector = &mut mc.borrow_mut(); - if let Some(ConnectorBound { connector, putter, .. }) = &mut mc.connector_bound { - match connector_put_bytes(connector, *putter, bytes_ptr as _, bytes_len) { - ERR_OK => connector_sync(connector, -1), - err => err as isize, - } - } else { - WRONG_STATE as isize // not bound! - } + let mut mc = if let Ok(mc) = mc.write() { mc } else { return FD_LOCK_POISONED as isize }; + let mc: &mut MaybeConnector = &mut mc; + mc.send(bytes_ptr, bytes_len) } #[no_mangle] @@ -163,20 +194,68 @@ pub unsafe extern "C" fn rw_recv( // ignoring flags let r = if let Ok(r) = CSP_STORAGE.read() { r } else { return FD_LOCK_POISONED as isize }; let mc = if let Some(mc) = r.fd_to_mc.get(&fd) { mc } else { return BAD_FD as isize }; - let mc: &mut MaybeConnector = &mut mc.borrow_mut(); - if let Some(ConnectorBound { connector, getter, .. }) = &mut mc.connector_bound { - connector_get(connector, *getter); - match connector_sync(connector, -1) { - 0 => { - // batch index 0 means OK - let slice = connector.gotten(*getter).unwrap().as_slice(); - let copied_bytes = slice.len().min(bytes_len); - std::ptr::copy_nonoverlapping(slice.as_ptr(), bytes_ptr as *mut u8, copied_bytes); - copied_bytes as isize - } - err => return err as isize, - } - } else { - WRONG_STATE as isize // not bound! + let mut mc = if let Ok(mc) = mc.write() { mc } else { return FD_LOCK_POISONED as isize }; + let mc: &mut MaybeConnector = &mut mc; + mc.recv(bytes_ptr, bytes_len) +} + +#[no_mangle] +pub unsafe extern "C" fn rw_sendto( + fd: c_int, + bytes_ptr: *mut c_void, + bytes_len: usize, + _flags: c_int, + peer_addr: *const SocketAddr, + _addr_len: usize, +) -> isize { + let r = if let Ok(r) = CSP_STORAGE.read() { r } else { return FD_LOCK_POISONED as isize }; + let mc = if let Some(mc) = r.fd_to_mc.get(&fd) { mc } else { return BAD_FD as isize }; + let mut mc = if let Ok(mc) = mc.write() { mc } else { return FD_LOCK_POISONED as isize }; + let mc: &mut MaybeConnector = &mut mc; + // copy currently connected peer addr + let connected = mc.peer_addr; + // connect to given peer_addr + match mc.connect(peer_addr.read()) { + e if e != ERR_OK => return e as isize, + _ => {} + } + // send + let ret = mc.send(bytes_ptr, bytes_len); + // restore connected peer addr + match mc.connect(connected) { + e if e != ERR_OK => return e as isize, + _ => {} + } + ret +} + +#[no_mangle] +#[no_mangle] +pub unsafe extern "C" fn rw_recvfrom( + fd: c_int, + bytes_ptr: *mut c_void, + bytes_len: usize, + _flags: c_int, + peer_addr: *const SocketAddr, + _addr_len: usize, +) -> isize { + let r = if let Ok(r) = CSP_STORAGE.read() { r } else { return FD_LOCK_POISONED as isize }; + let mc = if let Some(mc) = r.fd_to_mc.get(&fd) { mc } else { return BAD_FD as isize }; + let mut mc = if let Ok(mc) = mc.write() { mc } else { return FD_LOCK_POISONED as isize }; + let mc: &mut MaybeConnector = &mut mc; + // copy currently connected peer addr + let connected = mc.peer_addr; + // connect to given peer_addr + match mc.connect(peer_addr.read()) { + e if e != ERR_OK => return e as isize, + _ => {} + } + // send + let ret = mc.send(bytes_ptr, bytes_len); + // restore connected peer addr + match mc.connect(connected) { + e if e != ERR_OK => return e as isize, + _ => {} } + ret }