diff --git a/src/runtime2/stdlib/internet.rs b/src/runtime2/stdlib/internet.rs index 9fe43a02c06d383d2e5c3296b3bda3bdff52f136..0013332ee7c5c7b50eae94a319637d642214ad80 100644 --- a/src/runtime2/stdlib/internet.rs +++ b/src/runtime2/stdlib/internet.rs @@ -10,6 +10,7 @@ use libc::{ #[derive(Debug)] pub enum SocketError { Opening, + Modifying, Binding, Listening, Connecting, @@ -24,17 +25,25 @@ enum SocketState { /// TCP connection pub struct SocketTcpClient { - socket_handle: libc::c_int + socket_handle: libc::c_int, + is_blocking: bool, } impl SocketTcpClient { pub fn new(ip: IpAddr, port: u16) -> Result { + const BLOCKING: bool = false; + let socket_handle = create_and_connect_socket( libc::SOCK_STREAM, libc::IPPROTO_TCP, ip, port )?; + if !set_socket_blocking(socket_handle, BLOCKING) { + unsafe{ libc::close(socket_handle); } + return Err(SocketError::Modifying); + } return Ok(SocketTcpClient{ socket_handle, + is_blocking: BLOCKING, }) } @@ -50,17 +59,59 @@ impl SocketTcpClient { return Ok(result as usize); } + /// Receives data from the TCP socket. Returns the number of bytes received. + /// More data may still arrive later even if the returned count is less than `buffer.len()`.
pub fn receive(&self, buffer: &mut [u8]) -> Result { + if self.is_blocking { + return self.receive_blocking(buffer); + } else { + return self.receive_nonblocking(buffer); + } + } + + #[inline] + fn receive_blocking(&self, buffer: &mut [u8]) -> Result { let result = unsafe { let message_pointer = buffer.as_mut_ptr().cast(); - libc::recv(self.socket_handle, message_pointer, buffer.len() as libc::size_t, 0) + libc::recv(self.socket_handle, message_pointer, buffer.len(), 0) }; if result < 0 { - return Err(()) + return Err(()); } return Ok(result as usize); } + + #[inline] + fn receive_nonblocking(&self, buffer: &mut [u8]) -> Result { + unsafe { + let mut message_pointer = buffer.as_mut_ptr().cast(); + let mut remaining = buffer.len(); + + loop { + // Receive more data + let result = libc::recv(self.socket_handle, message_pointer, remaining, 0); + if result < 0 { + // Would-block/again means "nothing more right now": report what we have; anything else is a real error + let errno = std::io::Error::last_os_error().raw_os_error().expect("os error after failed recv"); + if errno == libc::EWOULDBLOCK || errno == libc::EAGAIN { + return Ok(buffer.len() - remaining); + } else { + return Err(()); + } + } + + // Advance past the received bytes; recv() == 0 (peer closed, or empty buffer) must also end the loop or it spins forever + let received = result as usize; + message_pointer = message_pointer.add(received); + remaining -= received; + + if remaining == 0 || received == 0 { + return Ok(buffer.len() - remaining); + } + } + } + } } impl Drop for SocketTcpClient { @@ -248,6 +299,33 @@ fn create_sockaddr_in_v6(ip: Ipv6Addr, port: u16) -> (sockaddr_in6, libc::sockle return (socket_address, address_size as _); } +#[inline] +fn set_socket_blocking(handle: libc::c_int, blocking: bool) -> bool { + if handle < 0 { + return false; + } + + unsafe{ + let mut flags = libc::fcntl(handle, libc::F_GETFL, 0); + if flags < 0 { + return false; + } + + if blocking { + flags &= !libc::O_NONBLOCK; + } else { + flags |= libc::O_NONBLOCK; + } + + let result = libc::fcntl(handle, libc::F_SETFL, flags); + if result < 0 { + return false; + } + } + + return true; +} + #[inline] fn
socket_family_from_ip(ip: IpAddr) -> libc::c_int { return match ip { @@ -267,11 +345,20 @@ mod tests { #[test] fn test_inet_thingo() { + const SIZE: usize = 1024; + let s = SocketTcpClient::new(IpAddr::V4(Ipv4Addr::new(142, 250, 179, 163)), 80).expect("connect"); s.send(b"GET / HTTP/1.1\r\n\r\n").expect("sending"); - let mut buffer = [0;65000]; - s.receive(&mut buffer).expect("receiving"); - let as_str = String::from_utf8_lossy(&buffer); - println!("Yay! Got:\n{}", as_str); + let mut total = Vec::::new(); + let mut buffer = [0; SIZE]; + let mut received = SIZE; + + while received > 0 { + received = s.receive(&mut buffer).expect("receiving"); + println!("DEBUG: Received {} bytes", received); + total.extend_from_slice(&buffer[..received]); + } + let as_str = String::from_utf8_lossy(total.as_slice()); + println!("Yay! Got {} bytes:\n{}", total.len(), as_str); } } \ No newline at end of file