Files @ fb814548c7d5
Branch filter:

Location: CSY/reowolf/src/runtime2/stdlib/internet.rs

fb814548c7d5 9.0 KiB application/rls-services+xml Show Annotation Show as Raw Download as Raw
mh
Add tcp component to standard library
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::mem::size_of;
use std::io::{Error as IoError, ErrorKind as IoErrorKind};

use libc::{
    c_int,
    sockaddr_in, sockaddr_in6, in_addr, in6_addr,
    socket, bind, listen, accept, connect, close,
};
use mio::{event, Interest, Registry, Token};

use crate::runtime2::poll::{AsFileDescriptor, FileDescriptor};

/// Identifies which step of setting up or using a socket failed.
///
/// The variants carry no OS error code; they only name the failing step.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketError {
    /// The `socket` call failed.
    Opening,
    /// Changing a property of the socket (e.g. its blocking mode) failed.
    Modifying,
    /// The `bind` call failed.
    Binding,
    /// The `listen` call failed.
    Listening,
    /// The `connect` call failed.
    Connecting,
    /// A connection was already accepted, so no further one can be accepted.
    Accepted,
    /// The `accept` call failed.
    Accepting,
}

impl std::fmt::Display for SocketError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            SocketError::Opening => "failed to open socket",
            SocketError::Modifying => "failed to modify socket settings",
            SocketError::Binding => "failed to bind socket",
            SocketError::Listening => "failed to listen on socket",
            SocketError::Connecting => "failed to connect socket",
            SocketError::Accepted => "socket has already accepted a connection",
            SocketError::Accepting => "failed to accept connection on socket",
        };
        f.write_str(description)
    }
}

impl std::error::Error for SocketError {}

/// State a socket can be in.
///
/// NOTE(review): this enum is not referenced anywhere in the visible part of
/// this file — confirm whether it is still needed.
enum SocketState {
    /// Presumably: the socket has been created but is not listening yet.
    Opened,
    /// Presumably: `listen` has been called on the socket.
    Listening,
}

/// TCP connection: owns the OS handle of a connected stream socket. The
/// handle is closed when the value is dropped.
pub struct SocketTcpClient {
    // Raw OS socket handle, as returned by `socket`.
    socket_handle: libc::c_int,
    // Blocking mode that was requested at construction (`new` stores `false`,
    // i.e. non-blocking).
    is_blocking: bool,
}

impl SocketTcpClient {
    /// Opens a TCP stream socket, connects it to `ip`:`port` and then puts the
    /// socket into non-blocking mode.
    ///
    /// # Errors
    /// Forwards the error of `create_and_connect_socket` when opening or
    /// connecting fails. Returns `SocketError::Modifying` (after closing the
    /// socket again) when the blocking mode cannot be changed.
    pub fn new(ip: IpAddr, port: u16) -> Result<Self, SocketError> {
        const BLOCKING: bool = false;

        let handle = create_and_connect_socket(
            libc::SOCK_STREAM, libc::IPPROTO_TCP, ip, port
        )?;

        if set_socket_blocking(handle, BLOCKING) {
            Ok(SocketTcpClient{
                socket_handle: handle,
                is_blocking: BLOCKING,
            })
        } else {
            // The connection is unusable for us; do not leak the handle.
            unsafe{ libc::close(handle); }
            Err(SocketError::Modifying)
        }
    }

    /// Sends (part of) `message` over the TCP socket. Returns the number of
    /// bytes that were handed to the OS, which may be less than
    /// `message.len()`.
    ///
    /// # Errors
    /// Returns the last OS error when `send` reports a failure.
    pub fn send(&self, message: &[u8]) -> Result<usize, IoError> {
        let sent = unsafe{
            libc::send(
                self.socket_handle,
                message.as_ptr().cast(),
                message.len() as libc::size_t,
                0,
            )
        };

        if sent < 0 {
            Err(IoError::last_os_error())
        } else {
            Ok(sent as usize)
        }
    }

    /// Receives data from the TCP socket into `buffer`. Returns the number of
    /// bytes received. More bytes may be pending even though the returned
    /// count is smaller than `buffer.len()`.
    ///
    /// # Errors
    /// Returns the last OS error when `recv` reports a failure.
    pub fn receive(&self, buffer: &mut [u8]) -> Result<usize, IoError> {
        let capacity = buffer.len();
        let received = unsafe {
            libc::recv(self.socket_handle, buffer.as_mut_ptr().cast(), capacity, 0)
        };

        if received < 0 {
            Err(IoError::last_os_error())
        } else {
            Ok(received as usize)
        }
    }
}

impl Drop for SocketTcpClient {
    /// Closes the underlying socket handle.
    fn drop(&mut self) {
        let handle = self.socket_handle;
        // A constructed client always holds a handle returned by `socket`.
        debug_assert!(handle >= 0);
        unsafe{ close(handle) };
    }
}

impl AsFileDescriptor for SocketTcpClient {
    /// Exposes the raw socket handle as a `FileDescriptor`.
    fn as_file_descriptor(&self) -> FileDescriptor {
        self.socket_handle
    }
}

/// Raw socket receiver. Essentially a listener that accepts a single connection
struct SocketRawRx {
    // Handle of the bound, listening socket.
    listen_handle: c_int,
    // Handle of the accepted connection; negative (`-1`) while no connection
    // has been accepted.
    accepted_handle: c_int,
}

impl SocketRawRx {
    pub fn new(ip: Option<Ipv4Addr>, port: u16) -> Result<Self, SocketError> {
        let ip = ip.unwrap_or(Ipv4Addr::UNSPECIFIED); // unspecified is the same as INADDR_ANY
        let address = unsafe{ in_addr{
            s_addr: std::mem::transmute(ip.octets()),
        }};
        let socket_address = sockaddr_in{
            sin_family: libc::AF_INET as libc::sa_family_t,
            sin_port: htons(port),
            sin_addr: address,
            sin_zero: [0; 8],
        };

        unsafe {
            let socket_handle = create_and_bind_socket(libc::SOCK_RAW, 0, IpAddr::V4(ip), port)?;

            let result = listen(socket_handle, 3);
            if result < 0 { return Err(SocketError::Listening); }

            return Ok(SocketRawRx{
                listen_handle: socket_handle,
                accepted_handle: -1,
            });
        }
    }

    // pub fn try_accept(&mut self, timeout_ms: u32) -> Result<(), SocketError> {
    //     if self.accepted_handle >= 0 {
    //         // Already accepted a connection
    //         return Err(SocketError::Accepted);
    //     }
    //
    //     let mut socket_address = sockaddr_in{
    //         sin_family: 0,
    //         sin_port: 0,
    //         sin_addr: in_addr{ s_addr: 0 },
    //         sin_zero: [0; 8]
    //     };
    //     let mut size = size_of::<sockaddr_in>() as u32;
    //     unsafe {
    //         let result = accept(self.listen_handle, &mut socket_address as *mut _, &mut size as *mut _);
    //         if result < 0 {
    //             return Err(SocketError::Accepting);
    //         }
    //     }
    //
    //     return Ok(());
    // }
}

impl Drop for SocketRawRx {
    /// Closes the accepted connection (if any) and then the listening socket.
    /// Negative handles mark "not open" and are skipped.
    fn drop(&mut self) {
        for &handle in [self.accepted_handle, self.listen_handle].iter() {
            if handle < 0 {
                continue;
            }

            unsafe {
                close(handle);
            }
        }
    }
}

// The following is essentially stolen from `mio`'s io_source.rs file.
/// Gives access to the raw OS handle of a socket type. Only compiled on
/// unix targets.
#[cfg(unix)]
trait AsRawFileDescriptor {
    /// Returns the raw handle as a C `int`.
    fn as_raw_file_descriptor(&self) -> c_int;
}

// Gated like the trait itself: `AsRawFileDescriptor` only exists on unix
// targets, so an ungated impl would refer to a missing trait elsewhere.
#[cfg(unix)]
impl AsRawFileDescriptor for SocketTcpClient {
    /// Returns the raw handle of the connected TCP socket.
    fn as_raw_file_descriptor(&self) -> c_int {
        self.socket_handle
    }
}

/// Opens a socket of the given type/protocol for the address family of `ip`
/// and binds it to `ip`:`port`.
///
/// Returns the socket handle on success. Fails with `SocketError::Opening`
/// when `socket` fails, and with `SocketError::Binding` when `bind` fails (in
/// which case the socket is closed again).
fn create_and_bind_socket(socket_type: libc::c_int, protocol: libc::c_int, ip: IpAddr, port: u16) -> Result<libc::c_int, SocketError> {
    let family = socket_family_from_ip(ip);

    let handle = unsafe{ socket(family, socket_type, protocol) };
    if handle < 0 {
        return Err(SocketError::Opening);
    }

    // The address structure must stay alive for the duration of the call, so
    // it is built and used within the same match arm.
    let bind_status = match ip {
        IpAddr::V4(v4) => {
            let (address, length) = create_sockaddr_in_v4(v4, port);
            unsafe{ bind(handle, (&address as *const sockaddr_in).cast(), length) }
        },
        IpAddr::V6(v6) => {
            let (address, length) = create_sockaddr_in_v6(v6, port);
            unsafe{ bind(handle, (&address as *const sockaddr_in6).cast(), length) }
        },
    };

    if bind_status >= 0 {
        return Ok(handle);
    }

    unsafe{ close(handle); }
    Err(SocketError::Binding)
}

/// Opens a socket of the given type/protocol for the address family of `ip`
/// and connects it to `ip`:`port`.
///
/// Returns the socket handle on success. Fails with `SocketError::Opening`
/// when `socket` fails, and with `SocketError::Connecting` when `connect`
/// fails (in which case the socket is closed again).
fn create_and_connect_socket(socket_type: libc::c_int, protocol: libc::c_int, ip: IpAddr, port: u16) -> Result<libc::c_int, SocketError> {
    let family = socket_family_from_ip(ip);

    let handle = unsafe{ socket(family, socket_type, protocol) };
    if handle < 0 {
        return Err(SocketError::Opening);
    }

    // The address structure must stay alive for the duration of the call, so
    // it is built and used within the same match arm.
    let connect_status = match ip {
        IpAddr::V4(v4) => {
            let (address, length) = create_sockaddr_in_v4(v4, port);
            unsafe{ connect(handle, (&address as *const sockaddr_in).cast(), length) }
        },
        IpAddr::V6(v6) => {
            let (address, length) = create_sockaddr_in_v6(v6, port);
            unsafe{ connect(handle, (&address as *const sockaddr_in6).cast(), length) }
        },
    };

    if connect_status >= 0 {
        return Ok(handle);
    }

    unsafe{ close(handle); }
    Err(SocketError::Connecting)
}

/// Builds the `sockaddr_in` for the IPv4 address `ip` and `port`, and returns
/// it together with its size in bytes (as expected by `bind`/`connect`).
#[inline]
fn create_sockaddr_in_v4(ip: Ipv4Addr, port: u16) -> (sockaddr_in, libc::socklen_t) {
    // `octets()` is already in network byte order. Reinterpreting those bytes
    // in native order keeps their in-memory layout unchanged, which is what
    // `s_addr` requires — no `unsafe` transmute needed.
    let address = in_addr{
        s_addr: u32::from_ne_bytes(ip.octets()),
    };

    let socket_address = sockaddr_in{
        sin_family: libc::AF_INET as libc::sa_family_t,
        sin_port: htons(port),
        sin_addr: address,
        sin_zero: [0; 8]
    };

    (socket_address, size_of::<sockaddr_in>() as libc::socklen_t)
}

/// Builds the `sockaddr_in6` for the IPv6 address `ip` and `port`, and returns
/// it together with its size in bytes (as expected by `bind`/`connect`).
#[inline]
fn create_sockaddr_in_v6(ip: Ipv6Addr, port: u16) -> (sockaddr_in6, libc::socklen_t) {
    // According to RFC6437 the flow label is advised to be a (somewhat
    // secure) random number drawn from a uniform distribution.
    let flow_label = rand::random();

    let socket_address = sockaddr_in6{
        sin6_family: libc::AF_INET6 as libc::sa_family_t,
        sin6_port: htons(port),
        sin6_flowinfo: flow_label,
        sin6_addr: in6_addr{ s6_addr: ip.octets() },
        sin6_scope_id: 0, // incorrect in case of loopback address
    };

    (socket_address, size_of::<sockaddr_in6>() as _)
}

/// Switches the socket `handle` to blocking (`blocking == true`) or
/// non-blocking mode by updating the `O_NONBLOCK` file status flag.
///
/// Returns `true` on success and `false` when `handle` is negative or when
/// reading or writing the flags fails.
#[inline]
fn set_socket_blocking(handle: libc::c_int, blocking: bool) -> bool {
    if handle < 0 {
        return false;
    }

    let current_flags = unsafe{ libc::fcntl(handle, libc::F_GETFL, 0) };
    if current_flags < 0 {
        return false;
    }

    let desired_flags = if blocking {
        current_flags & !libc::O_NONBLOCK
    } else {
        current_flags | libc::O_NONBLOCK
    };

    let status = unsafe{ libc::fcntl(handle, libc::F_SETFL, desired_flags) };
    status >= 0
}

/// Maps an IP address to the matching socket address family: `AF_INET` for
/// IPv4 and `AF_INET6` for IPv6.
#[inline]
fn socket_family_from_ip(ip: IpAddr) -> libc::c_int {
    if ip.is_ipv4() {
        libc::AF_INET
    } else {
        libc::AF_INET6
    }
}

/// Converts a 16-bit value from host byte order to network (big-endian) byte
/// order.
#[inline]
fn htons(port: u16) -> u16 {
    // Lay the value out big-endian and read those bytes back unchanged.
    u16::from_ne_bytes(port.to_be_bytes())
}