Files @ dc1a7211fdca
Branch filter:

Location: CSY/reowolf/src/runtime2/stdlib/internet.rs

dc1a7211fdca 7.9 KiB application/rls-services+xml Show Annotation Show as Raw Download as Raw
MH
Initial tcp socket implementation
use std::net::{IpAddr, Ipv4Addr, Ipv6Addr};
use std::mem::size_of;

use libc::{
    c_int,
    sockaddr_in, sockaddr_in6, in_addr, in6_addr,
    socket, bind, listen, accept, connect, close,
};

/// Error returned by the socket helpers in this module. Each variant names the
/// step that failed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SocketError {
    /// The `socket` call failed to create a descriptor.
    Opening,
    /// The `bind` call failed.
    Binding,
    /// The `listen` call failed.
    Listening,
    /// The `connect` call failed.
    Connecting,
    /// A connection has already been accepted on this listener.
    Accepted,
    /// The `accept` call failed.
    Accepting,
}

impl std::fmt::Display for SocketError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let description = match self {
            SocketError::Opening => "failed to open socket",
            SocketError::Binding => "failed to bind socket",
            SocketError::Listening => "failed to listen on socket",
            SocketError::Connecting => "failed to connect socket",
            SocketError::Accepted => "a connection has already been accepted",
            SocketError::Accepting => "failed to accept connection",
        };
        f.write_str(description)
    }
}

/// Lets `SocketError` be used with `?` in functions returning
/// `Box<dyn std::error::Error>` and similar error-erasing types.
impl std::error::Error for SocketError {}

/// Stage a socket is in.
///
/// NOTE(review): not referenced anywhere in the visible code yet. The variant
/// meanings below are read from their names only and should be confirmed once
/// the type is put to use.
enum SocketState {
    /// Presumably: the socket has been created.
    Opened,
    /// Presumably: `listen` has succeeded on the socket.
    Listening,
}

/// TCP connection.
///
/// Owns the connected socket descriptor, which is closed when the value is
/// dropped.
pub struct SocketTcpClient {
    socket_handle: libc::c_int
}

impl SocketTcpClient {
    /// Opens a TCP socket and connects it to `ip:port`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `create_and_connect_socket`: `Opening` when
    /// the socket cannot be created, `Connecting` when the connection fails.
    pub fn new(ip: IpAddr, port: u16) -> Result<Self, SocketError> {
        let socket_handle = create_and_connect_socket(
            libc::SOCK_STREAM, libc::IPPROTO_TCP, ip, port
        )?;

        Ok(SocketTcpClient {
            socket_handle,
        })
    }

    /// Sends bytes from `message` and returns how many were accepted.
    ///
    /// The count may be smaller than `message.len()`. Callers that need the
    /// whole message delivered must call again with the remainder. A call
    /// interrupted by a signal before any data was sent is retried.
    /// `Err(())` is returned for any other failure.
    pub fn send(&self, message: &[u8]) -> Result<usize, ()> {
        // Without MSG_NOSIGNAL, sending on a connection the peer has closed
        // raises SIGPIPE. Unless the process ignores that signal (e.g. when
        // this code runs inside a non-Rust host), the default action kills the
        // process. With the flag the call fails with EPIPE instead, matching
        // what `std::net::TcpStream` does on these platforms.
        #[cfg(any(target_os = "linux", target_os = "android"))]
        const SEND_FLAGS: libc::c_int = libc::MSG_NOSIGNAL;
        #[cfg(not(any(target_os = "linux", target_os = "android")))]
        const SEND_FLAGS: libc::c_int = 0;

        loop {
            // SAFETY: the pointer/length pair describes the borrowed `message`
            // slice, which outlives the call.
            let result = unsafe {
                let message_pointer = message.as_ptr().cast();
                libc::send(self.socket_handle, message_pointer, message.len() as libc::size_t, SEND_FLAGS)
            };
            if result >= 0 {
                return Ok(result as usize);
            }
            if !Self::last_error_was_interrupt() {
                return Err(());
            }
        }
    }

    /// Receives up to `buffer.len()` bytes into `buffer` and returns how many
    /// were written.
    ///
    /// `Ok(0)` means the peer has shut down the connection, or `buffer` is
    /// empty. A call interrupted by a signal is retried. `Err(())` is returned
    /// for any other failure.
    pub fn receive(&self, buffer: &mut [u8]) -> Result<usize, ()> {
        loop {
            // SAFETY: the pointer/length pair describes the mutably borrowed
            // `buffer`; `recv` writes at most `buffer.len()` bytes into it.
            let result = unsafe {
                let message_pointer = buffer.as_mut_ptr().cast();
                libc::recv(self.socket_handle, message_pointer, buffer.len() as libc::size_t, 0)
            };
            if result >= 0 {
                return Ok(result as usize);
            }
            if !Self::last_error_was_interrupt() {
                return Err(());
            }
        }
    }

    /// Returns `true` when `errno` from the last failed libc call is `EINTR`.
    fn last_error_was_interrupt() -> bool {
        std::io::Error::last_os_error().kind() == std::io::ErrorKind::Interrupted
    }
}

impl Drop for SocketTcpClient {
    fn drop(&mut self) {
        debug_assert!(self.socket_handle >= 0);
        // SAFETY: `socket_handle` came from a successful socket/connect in
        // `new` and is owned exclusively by this value. It is closed once, here.
        unsafe { close(self.socket_handle) };
    }
}

/// Raw socket receiver. Essentially a listener that accepts a single connection.
///
/// Owns up to two socket descriptors. A negative value means "not open", and
/// `Drop` skips closing it.
struct SocketRawRx {
    /// Descriptor created and bound by `new`.
    listen_handle: c_int,
    /// Descriptor of an accepted connection; `-1` until one is accepted.
    accepted_handle: c_int,
}

impl SocketRawRx {
    /// Creates a raw IPv4 socket bound to `ip:port` and calls `listen` on it.
    /// When `ip` is `None`, the socket binds to `0.0.0.0` (`INADDR_ANY`).
    ///
    /// # Errors
    ///
    /// `Opening`/`Binding` from socket creation and binding, and `Listening`
    /// when `listen` fails. On every error path no descriptor is left open.
    pub fn new(ip: Option<Ipv4Addr>, port: u16) -> Result<Self, SocketError> {
        let ip = ip.unwrap_or(Ipv4Addr::UNSPECIFIED); // unspecified is the same as INADDR_ANY
        let socket_handle = create_and_bind_socket(libc::SOCK_RAW, 0, IpAddr::V4(ip), port)?;

        // NOTE(review): listen(2) is only specified for SOCK_STREAM and
        // SOCK_SEQPACKET sockets. On a SOCK_RAW socket it is expected to fail
        // (EOPNOTSUPP on Linux), so this constructor may always return
        // `Err(SocketError::Listening)`. Confirm the intended socket type.
        // SAFETY: `socket_handle` is a valid descriptor we exclusively own.
        let result = unsafe { listen(socket_handle, 3) };
        if result < 0 {
            // No `SocketRawRx` owns the handle yet, so `Drop` will not close
            // it. Close it here instead of leaking the descriptor.
            // SAFETY: the handle is valid and is not used after this call.
            unsafe { close(socket_handle) };
            return Err(SocketError::Listening);
        }

        Ok(SocketRawRx {
            listen_handle: socket_handle,
            accepted_handle: -1,
        })
    }

    // pub fn try_accept(&mut self, timeout_ms: u32) -> Result<(), SocketError> {
    //     if self.accepted_handle >= 0 {
    //         // Already accepted a connection
    //         return Err(SocketError::Accepted);
    //     }
    //
    //     let mut socket_address = sockaddr_in{
    //         sin_family: 0,
    //         sin_port: 0,
    //         sin_addr: in_addr{ s_addr: 0 },
    //         sin_zero: [0; 8]
    //     };
    //     let mut size = size_of::<sockaddr_in>() as u32;
    //     unsafe {
    //         let result = accept(self.listen_handle, &mut socket_address as *mut _, &mut size as *mut _);
    //         if result < 0 {
    //             return Err(SocketError::Accepting);
    //         }
    //     }
    //
    //     return Ok(());
    // }
}

impl Drop for SocketRawRx {
    /// Closes whichever descriptors are open: the accepted connection first,
    /// then the listening socket.
    fn drop(&mut self) {
        if self.accepted_handle >= 0 {
            // SAFETY: a non-negative handle is an open descriptor owned by self.
            unsafe {
                close(self.accepted_handle);
            }
        }

        if self.listen_handle >= 0 {
            // SAFETY: a non-negative handle is an open descriptor owned by self.
            unsafe {
                close(self.listen_handle);
            }
        }
    }
}



/// Performs the `socket` and `bind` calls.
///
/// Creates a socket of `socket_type`/`protocol` in the address family that
/// matches `ip` and binds it to `ip:port`. On success the descriptor is
/// returned and the caller becomes responsible for closing it. If `bind`
/// fails, the descriptor is closed before the error is returned.
///
/// # Errors
///
/// [`SocketError::Opening`] when `socket` fails, [`SocketError::Binding`]
/// when `bind` fails.
fn create_and_bind_socket(socket_type: libc::c_int, protocol: libc::c_int, ip: IpAddr, port: u16) -> Result<libc::c_int, SocketError> {
    let domain = socket_family_from_ip(ip);

    // SAFETY: `socket` only takes integer arguments.
    let fd = unsafe { socket(domain, socket_type, protocol) };
    if fd < 0 {
        return Err(SocketError::Opening);
    }

    // SAFETY (both arms): the pointer refers to a live, initialised socket
    // address whose size is exactly `len` bytes for the duration of the call.
    let status = match ip {
        IpAddr::V4(v4) => {
            let (addr, len) = create_sockaddr_in_v4(v4, port);
            unsafe { bind(fd, (&addr as *const sockaddr_in).cast(), len) }
        }
        IpAddr::V6(v6) => {
            let (addr, len) = create_sockaddr_in_v6(v6, port);
            unsafe { bind(fd, (&addr as *const sockaddr_in6).cast(), len) }
        }
    };

    if status >= 0 {
        Ok(fd)
    } else {
        // SAFETY: `fd` was opened above and has not been handed to anyone.
        unsafe { close(fd) };
        Err(SocketError::Binding)
    }
}

/// Performs the `socket` and `connect` calls.
///
/// Creates a socket of `socket_type`/`protocol` in the address family that
/// matches `ip` and connects it to `ip:port`. On success the connected
/// descriptor is returned and the caller becomes responsible for closing it.
/// If `connect` fails, the descriptor is closed before the error is returned.
///
/// # Errors
///
/// [`SocketError::Opening`] when `socket` fails, [`SocketError::Connecting`]
/// when `connect` fails.
fn create_and_connect_socket(socket_type: libc::c_int, protocol: libc::c_int, ip: IpAddr, port: u16) -> Result<libc::c_int, SocketError> {
    let domain = socket_family_from_ip(ip);

    // SAFETY: `socket` only takes integer arguments.
    let fd = unsafe { socket(domain, socket_type, protocol) };
    if fd < 0 {
        return Err(SocketError::Opening);
    }

    // SAFETY (both arms): the pointer refers to a live, initialised socket
    // address whose size is exactly `len` bytes for the duration of the call.
    let status = match ip {
        IpAddr::V4(v4) => {
            let (addr, len) = create_sockaddr_in_v4(v4, port);
            unsafe { connect(fd, (&addr as *const sockaddr_in).cast(), len) }
        }
        IpAddr::V6(v6) => {
            let (addr, len) = create_sockaddr_in_v6(v6, port);
            unsafe { connect(fd, (&addr as *const sockaddr_in6).cast(), len) }
        }
    };

    if status < 0 {
        // SAFETY: `fd` was opened above and has not been handed to anyone.
        unsafe { close(fd) };
        return Err(SocketError::Connecting);
    }

    Ok(fd)
}

/// Builds an IPv4 socket address for `ip:port`. Returns the address together
/// with its size in bytes, ready to pass to `bind`/`connect`.
///
/// The address and port are stored in network byte order.
#[inline]
fn create_sockaddr_in_v4(ip: Ipv4Addr, port: u16) -> (sockaddr_in, libc::socklen_t) {
    // SAFETY: `sockaddr_in` is a plain C struct of integers and byte arrays,
    // for which the all-zero bit pattern is valid. Starting from zero, rather
    // than a struct literal, also fills `sin_zero` and any platform-specific
    // fields (e.g. `sin_len` on BSD/macOS) without naming them.
    let mut socket_address: sockaddr_in = unsafe { std::mem::zeroed() };
    socket_address.sin_family = libc::AF_INET as libc::sa_family_t;
    socket_address.sin_port = htons(port);
    // `octets()` is already in network order. Reinterpreting those bytes in
    // native order keeps that memory layout (a safe replacement for the
    // `transmute` this used to need).
    socket_address.sin_addr = in_addr {
        s_addr: u32::from_ne_bytes(ip.octets()),
    };

    let address_size = size_of::<sockaddr_in>();

    (socket_address, address_size as _)
}

/// Builds an IPv6 socket address for `ip:port`. Returns the address together
/// with its size in bytes, ready to pass to `bind`/`connect`.
///
/// `sin6_flowinfo` and `sin6_scope_id` are both left at 0 (see below).
#[inline]
fn create_sockaddr_in_v6(ip: Ipv6Addr, port: u16) -> (sockaddr_in6, libc::socklen_t) {
    // SAFETY: `sockaddr_in6` is a plain C struct of integers and byte arrays,
    // for which the all-zero bit pattern is valid. Starting from zero also
    // fills platform-specific fields (e.g. `sin6_len` on BSD/macOS).
    let mut socket_address: sockaddr_in6 = unsafe { std::mem::zeroed() };
    socket_address.sin6_family = libc::AF_INET6 as libc::sa_family_t;
    socket_address.sin6_port = htons(port);
    // `sin6_flowinfo` holds both the traffic class and the 20-bit flow label,
    // so a random 32-bit value would also set arbitrary traffic-class bits.
    // RFC 6437 expects the source node's stack to generate flow labels, so
    // pass 0 ("unlabelled") and let the kernel choose.
    socket_address.sin6_flowinfo = 0;
    // Constructing `in6_addr` is a plain struct literal; no `unsafe` needed.
    socket_address.sin6_addr = in6_addr {
        s6_addr: ip.octets(),
    };
    // NOTE(review): a scope id of 0 is fine for global and loopback addresses.
    // Link-local (fe80::/10) addresses need the interface index here, and are
    // typically rejected by bind/connect without it.
    socket_address.sin6_scope_id = 0;

    let address_size = size_of::<sockaddr_in6>();

    (socket_address, address_size as _)
}

/// Picks the socket address family for `ip`: `AF_INET` for IPv4 addresses,
/// `AF_INET6` for IPv6 addresses.
#[inline]
fn socket_family_from_ip(ip: IpAddr) -> libc::c_int {
    if ip.is_ipv4() {
        libc::AF_INET
    } else {
        libc::AF_INET6
    }
}

/// Converts a port number from host byte order to network (big-endian) byte
/// order, like C's `htons`. The returned `u16`'s in-memory bytes are
/// big-endian regardless of the host's endianness.
#[inline]
fn htons(port: u16) -> u16 {
    u16::from_ne_bytes(port.to_be_bytes())
}

#[cfg(test)]
mod tests {
    use std::net::*;
    use super::*;

    /// Smoke test for `SocketTcpClient`. It connects to a public HTTP server,
    /// sends a minimal request and prints the first chunk of the reply.
    /// It needs outbound network access, so it is ignored by default; run it
    /// with `cargo test -- --ignored`.
    #[test]
    #[ignore = "requires outbound network access to 142.250.179.163:80"]
    fn test_inet_thingo() {
        let s = SocketTcpClient::new(IpAddr::V4(Ipv4Addr::new(142, 250, 179, 163)), 80).expect("connect");
        s.send(b"GET / HTTP/1.1\r\n\r\n").expect("sending");
        let mut buffer = [0;65000];
        let received = s.receive(&mut buffer).expect("receiving");
        // Only the first `received` bytes were filled in by `recv`. Printing the
        // whole buffer would dump tens of kilobytes of zero padding too.
        let as_str = String::from_utf8_lossy(&buffer[..received]);
        println!("Yay! Got:\n{}", as_str);
    }
}