diff --git a/src/constants.rs b/src/constants.rs index 8d014ed..2392c02 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -8,6 +8,11 @@ pub const ICMP_ECHO_REPLY: u8 = 0; pub const ICMP6_ECHO_REQUEST: u8 = 128; pub const ICMP6_ECHO_REPLY: u8 = 129; pub const ICMP_HEADER_LEN: usize = 8; +pub const ICMP_TYPE_OFFSET: usize = 0; +pub const ICMP_ID_OFFSET: usize = 4; +pub const ICMP_SEQ_OFFSET: usize = 6; +pub const ICMP_MIN_LEN: usize = 8; +pub const IPV4_MIN_HDR_LEN: usize = 20; /* Default values */ pub const _DEFAULT_INTERVAL_MS: u64 = 10; @@ -24,6 +29,7 @@ pub const _MAX_GENERATE: usize = 131_072; pub const _MAX_TARGET_NAME: usize = 255; /* Response flags */ +pub const RESP_TIMES_CAP: usize = 1000; pub const _RESP_WAITING: i64 = -1; pub const _RESP_UNUSED: i64 = -2; pub const _RESP_TIMEOUT: i64 = -4; diff --git a/src/socket.rs b/src/socket.rs index b8641e7..34ee9fa 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -2,11 +2,13 @@ use libc::{ c_void, recvfrom, sendto, sockaddr, socklen_t, AF_INET, AF_INET6, IPPROTO_ICMP, IPPROTO_ICMPV6, SOCK_DGRAM, SOCK_RAW, }; +use std::io::ErrorKind; use std::net::{Ipv4Addr, Ipv6Addr}; use std::os::unix::io::RawFd; use crate::constants::{ ICMP6_ECHO_REPLY, ICMP6_ECHO_REQUEST, ICMP_ECHO_REPLY, ICMP_ECHO_REQUEST, ICMP_HEADER_LEN, + ICMP_TYPE_OFFSET, ICMP_ID_OFFSET, ICMP_SEQ_OFFSET, ICMP_MIN_LEN, IPV4_MIN_HDR_LEN, }; #[derive(Clone, Copy, PartialEq, Eq, Debug)] @@ -24,15 +26,25 @@ pub fn open_raw_socket(is_ipv6: bool) -> Result<(RawFd, SocketKind, Option) let fd = unsafe { libc::socket(domain, SOCK_RAW, proto) }; if fd >= 0 { - set_nonblocking(fd); + if let Err(e) = set_nonblocking(fd) { + eprintln!("Warning: set_nonblocking failed for RAW socket: {}", e); + } return Ok((fd, SocketKind::Raw, None)); } let fd = unsafe { libc::socket(domain, SOCK_DGRAM, proto) }; if fd >= 0 { - set_nonblocking(fd); - let assigned_id = dgram_bind_and_get_id(fd, is_ipv6); - return Ok((fd, SocketKind::Dgram, assigned_id)); + if let Err(e) = set_nonblocking(fd) { + eprintln!("Warning: set_nonblocking failed for DGRAM socket: {}", e); + } + + match dgram_bind_and_get_id(fd, is_ipv6) { + Ok(assigned_id) => return Ok((fd, SocketKind::Dgram, assigned_id)), + Err(e) => { + unsafe { libc::close(fd) }; + return Err(format!("DGRAM socket bind/getsockname failed: {}", e)); + } + } } Err(format!( @@ -44,7 +56,7 @@ pub fn open_raw_socket(is_ipv6: bool) -> Result<(RawFd, SocketKind, Option) )) } -fn dgram_bind_and_get_id(fd: RawFd, is_ipv6: bool) -> Option { +fn dgram_bind_and_get_id(fd: RawFd, is_ipv6: bool) -> Result, std::io::Error> { unsafe { if is_ipv6 { let mut sa: libc::sockaddr_in6 = std::mem::zeroed(); @@ -54,13 +66,20 @@ fn dgram_bind_and_get_id(fd: RawFd, is_ipv6: bool) -> Option { &sa as *const _ as *const libc::sockaddr, std::mem::size_of::() as socklen_t, ); - if r < 0 { return None; } + if r < 0 { + return Err(std::io::Error::last_os_error()); + } let mut sa2: libc::sockaddr_in6 = std::mem::zeroed(); let mut len = std::mem::size_of::() as socklen_t; let r = libc::getsockname(fd, &mut sa2 as *mut _ as *mut libc::sockaddr, &mut len); - if r < 0 || sa2.sin6_port == 0 { return None; } - Some(u16::from_be(sa2.sin6_port)) + if r < 0 { + return Err(std::io::Error::last_os_error()); + } + if sa2.sin6_port == 0 { + return Ok(None); + } + Ok(Some(u16::from_be(sa2.sin6_port))) } else { let mut sa: libc::sockaddr_in = std::mem::zeroed(); sa.sin_family = AF_INET as libc::sa_family_t; @@ -69,38 +88,52 @@ fn dgram_bind_and_get_id(fd: RawFd, is_ipv6: bool) -> Option { &sa as *const _ as *const libc::sockaddr, std::mem::size_of::() as socklen_t, ); - if r < 0 { return None; } + if r < 0 { + return Err(std::io::Error::last_os_error()); + } let mut sa2: libc::sockaddr_in = std::mem::zeroed(); let mut len = std::mem::size_of::() as socklen_t; let r = libc::getsockname(fd, &mut sa2 as *mut _ as *mut libc::sockaddr, &mut len); - if r < 0 || sa2.sin_port == 0 { return None; } - Some(u16::from_be(sa2.sin_port)) + if r < 0 { + return Err(std::io::Error::last_os_error()); + } + if sa2.sin_port == 0 { + return Ok(None); + } + Ok(Some(u16::from_be(sa2.sin_port))) } } } -fn set_nonblocking(fd: RawFd) { +fn set_nonblocking(fd: RawFd) -> Result<(), std::io::Error> { unsafe { let flags = libc::fcntl(fd, libc::F_GETFL, 0); - libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK); + if flags < 0 { + return Err(std::io::Error::last_os_error()); + } + let r = libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK); + if r < 0 { + return Err(std::io::Error::last_os_error()); + } + Ok(()) } } -pub fn build_icmp_packet(id: u16, seq: u16, data_size: usize, is_ipv6: bool, _kind: SocketKind) -> Vec { +pub fn build_icmp_packet(id: u16, seq: u16, data_size: usize, is_ipv6: bool, kind: SocketKind) -> Vec { let total = ICMP_HEADER_LEN + data_size; let mut pkt = vec![0u8; total]; - pkt[0] = if is_ipv6 { ICMP6_ECHO_REQUEST } else { ICMP_ECHO_REQUEST }; - pkt[1] = 0; - pkt[2] = 0; - pkt[3] = 0; - pkt[4] = (id >> 8) as u8; - pkt[5] = (id & 0xFF) as u8; - pkt[6] = (seq >> 8) as u8; - pkt[7] = (seq & 0xFF) as u8; + pkt[ICMP_TYPE_OFFSET] = if is_ipv6 { ICMP6_ECHO_REQUEST } else { ICMP_ECHO_REQUEST }; + pkt[1] = 0; // Code + pkt[2] = 0; // Checksum high + pkt[3] = 0; // Checksum low + pkt[ICMP_ID_OFFSET] = (id >> 8) as u8; + pkt[ICMP_ID_OFFSET + 1] = (id & 0xFF) as u8; + pkt[ICMP_SEQ_OFFSET] = (seq >> 8) as u8; + pkt[ICMP_SEQ_OFFSET + 1] = (seq & 0xFF) as u8; - for (i, b) in pkt[8..].iter_mut().enumerate() { + for (i, b) in pkt[ICMP_MIN_LEN..].iter_mut().enumerate() { *b = (i & 0xFF) as u8; } @@ -108,6 +141,8 @@ pub fn build_icmp_packet(id: u16, seq: u16, data_size: usize, is_ipv6: bool, _ki let cksum = icmp_checksum(&pkt); pkt[2] = (cksum >> 8) as u8; pkt[3] = (cksum & 0xFF) as u8; + } else if kind == SocketKind::Raw { + eprintln!("Warning: ICMPv6 checksum not computed for RAW socket – packet may be dropped"); } pkt @@ -170,6 +205,45 @@ pub struct ReceivedPing { pub raw_len: usize, } +fn parse_icmp_packet(data: &[u8], is_ipv6: bool, kind: SocketKind, expected_id: Option) -> Option { + let icmp = if !is_ipv6 && kind == SocketKind::Raw { + if data.len() < IPV4_MIN_HDR_LEN + ICMP_MIN_LEN { + return None; + } + let ihl = ((data[0] & 0x0F) as usize) * 4; + if data.len() < ihl + ICMP_MIN_LEN { + return None; + } + &data[ihl..] + } else { + if data.len() < ICMP_MIN_LEN { + return None; + } + data + }; + + let icmp_type = icmp[ICMP_TYPE_OFFSET]; + + let is_reply = if is_ipv6 { + icmp_type == ICMP6_ECHO_REPLY + } else { + icmp_type == ICMP_ECHO_REPLY + }; + + if !is_reply { + return None; + } + + let id = u16::from_be_bytes([icmp[ICMP_ID_OFFSET], icmp[ICMP_ID_OFFSET + 1]]); + if let Some(eid) = expected_id { + if id != eid { + return None; + } + } + + Some(u16::from_be_bytes([icmp[ICMP_SEQ_OFFSET], icmp[ICMP_SEQ_OFFSET + 1]])) +} + pub fn recv_ping(fd: RawFd, buf: &mut [u8], is_ipv6: bool, kind: SocketKind, expected_id: Option) -> Option { let mut src: libc::sockaddr_storage = unsafe { std::mem::zeroed() }; let mut src_len = std::mem::size_of::() as socklen_t; @@ -186,36 +260,16 @@ pub fn recv_ping(fd: RawFd, buf: &mut [u8], is_ipv6: bool, kind: SocketKind, exp }; if n < 0 { + let err = std::io::Error::last_os_error(); + if err.kind() != ErrorKind::WouldBlock { + eprintln!("recv_ping: recvfrom error: {}", err); + } return None; } let raw_len = n as usize; let data = &buf[..raw_len]; + let seq = parse_icmp_packet(data, is_ipv6, kind, expected_id)?; - let icmp = if !is_ipv6 && kind == SocketKind::Raw { - if data.len() < 20 + 8 { return None; } - let ihl = ((data[0] & 0x0F) as usize) * 4; - if data.len() < ihl + 8 { return None; } - &data[ihl..] - } else { - if data.len() < 8 { return None; } - data - }; - - let icmp_type = icmp[0]; - let is_reply = icmp_type == ICMP_ECHO_REPLY || icmp_type == ICMP6_ECHO_REPLY; - if !is_reply { - return None; - } - - let id = u16::from_be_bytes([icmp[4], icmp[5]]); - - if let Some(eid) = expected_id { - if id != eid { return None; } - } - - Some(ReceivedPing { - seq: u16::from_be_bytes([icmp[6], icmp[7]]), - raw_len, - }) + Some(ReceivedPing { seq, raw_len }) } \ No newline at end of file diff --git a/src/types.rs b/src/types.rs index c9dcc2e..118194a 100644 --- a/src/types.rs +++ b/src/types.rs @@ -1,5 +1,6 @@ use std::net::IpAddr; use std::time::{Duration, Instant}; +use crate::constants::RESP_TIMES_CAP; #[derive(Debug)] pub struct HostEntry { @@ -57,8 +58,12 @@ impl HostEntry { // count-mode: pre-allocated slot exists, write in place self.resp_times[idx] = Some(rtt); } else { - // loop/default-mode: resp_times is empty or shorter than ping_index, append dynamically so -s statistics are never empty - self.resp_times.push(Some(rtt)); + if self.resp_times.len() < RESP_TIMES_CAP { + self.resp_times.push(Some(rtt)); + } else { + let slot = (ping_index as usize) % RESP_TIMES_CAP; + self.resp_times[slot] = Some(rtt); + } } }