From 5b5ce9fde7c3999fc81691254ddb893936b7375d Mon Sep 17 00:00:00 2001 From: Ryuta Kambe Date: Thu, 3 Jul 2025 15:57:54 +0900 Subject: [PATCH] feat(virtio): tx checksum offloading Signed-off-by: Ryuta Kambe --- awkernel_drivers/src/pcie/virtio.rs | 2 + .../src/pcie/virtio/virtio_net.rs | 105 +++++++++++++++--- 2 files changed, 89 insertions(+), 18 deletions(-) diff --git a/awkernel_drivers/src/pcie/virtio.rs b/awkernel_drivers/src/pcie/virtio.rs index b9799d679..212efa5da 100644 --- a/awkernel_drivers/src/pcie/virtio.rs +++ b/awkernel_drivers/src/pcie/virtio.rs @@ -33,6 +33,7 @@ pub enum VirtioDriverErr { NoVirtqueue, InvalidQueueSize, DMAPool, + InvalidPacket, } impl From for PCIeDeviceErr { @@ -46,6 +47,7 @@ impl From for PCIeDeviceErr { VirtioDriverErr::NoVirtqueue => PCIeDeviceErr::InitFailure, VirtioDriverErr::InvalidQueueSize => PCIeDeviceErr::InitFailure, VirtioDriverErr::DMAPool => PCIeDeviceErr::InitFailure, + VirtioDriverErr::InvalidPacket => PCIeDeviceErr::CommandFailure, } } } diff --git a/awkernel_drivers/src/pcie/virtio/virtio_net.rs b/awkernel_drivers/src/pcie/virtio/virtio_net.rs index cbcc9950d..69bcba20c 100644 --- a/awkernel_drivers/src/pcie/virtio/virtio_net.rs +++ b/awkernel_drivers/src/pcie/virtio/virtio_net.rs @@ -21,8 +21,15 @@ use awkernel_lib::{ addr::Addr, dma_pool::DMAPool, interrupt::IRQ, - net::net_device::{ - EtherFrameBuf, EtherFrameRef, LinkStatus, NetCapabilities, NetDevError, NetDevice, NetFlags, + net::{ + ether::{extract_headers, EtherHeader, EtherVlanHeader, NetworkHdr, TransportHdr}, + ipv6::Ip6Hdr, + net_device::{ + EtherFrameBuf, EtherFrameRef, LinkStatus, NetCapabilities, NetDevError, NetDevice, + NetFlags, PacketHeaderFlags, + }, + tcp::TCPHdr, + udp::UDPHdr, }, paging::PAGESIZE, sync::{ @@ -30,6 +37,7 @@ use awkernel_lib::{ rwlock::RwLock, }, }; +use memoffset::offset_of; const DEVICE_SHORT_NAME: &str = "virtio-net"; @@ -38,6 +46,7 @@ const RECV_QUEUE_SIZE: usize = 32; // To Be Determined const VIRTIO_NET_ID: u16 = 0x1041; // device-specific feature bits +const VIRTIO_NET_F_CSUM: u64 = 1 << 0; const VIRTIO_NET_F_MAC: u64 = 1 << 5; const VIRTIO_NET_F_STATUS: u64 = 1 << 16; const VIRTIO_NET_F_SPEED_DUPLEX: u64 = 1 << 63; @@ -274,6 +283,59 @@ impl Virtq { } } + fn vio_tx_offload(&mut self, frame: &EtherFrameRef) -> Result { + let mut hdr = VirtioNetHdr::default(); + + let has_tcp_csum_out = frame.csum_flags.contains(PacketHeaderFlags::TCP_CSUM_OUT); + let has_udp_csum_out = frame.csum_flags.contains(PacketHeaderFlags::UDP_CSUM_OUT); + if !has_tcp_csum_out && !has_udp_csum_out { + return Ok(hdr); + } + + let ext = extract_headers(frame.data).or(Err(VirtioDriverErr::InvalidPacket))?; + + // Consistency check + match ext.network { + NetworkHdr::Ipv4(_) => (), + NetworkHdr::Ipv6(_) => (), + _ => return Ok(hdr), + } + match ext.transport { + TransportHdr::Tcp(_) => { + if !has_tcp_csum_out { + return Ok(hdr); + } + } + TransportHdr::Udp(_) => { + if !has_udp_csum_out { + return Ok(hdr); + } + } + _ => return Ok(hdr), + } + + hdr.csum_start = match ext.ether_vlan { + Some(_) => core::mem::size_of::() as u16, + None => core::mem::size_of::() as u16, + }; + + hdr.csum_start += match ext.network { + NetworkHdr::Ipv4(ip) => ip.header_len() as u16, + NetworkHdr::Ipv6(_) => core::mem::size_of::() as u16, + _ => 0, + }; + + hdr.csum_offset = match ext.transport { + TransportHdr::Tcp(_) => offset_of!(TCPHdr, th_sum) as u16, + TransportHdr::Udp(_) => offset_of!(UDPHdr, uh_sum) as u16, + _ => 0, + }; + + hdr.flags = VIRTIO_NET_HDR_F_NEEDS_CSUM; + + Ok(hdr) + } + fn vio_tx_dequeue(&mut self) -> u16 { let mut freed = 0; while let Some((slot, _len)) = self.virtio_dequeue() { @@ -295,25 +357,25 @@ impl Virtq { self.vio_txeof(); } - fn vio_encap(&mut self, slot: usize, frame: &EtherFrameRef) -> usize { - let len = frame.data.len(); + fn vio_encap(&mut self, slot: usize, frame: &EtherFrameRef, header: &VirtioNetHdr) -> usize { let buf = self.data_buf.as_mut(); let dst = &mut buf[slot].as_mut_ptr(); let header_len = core::mem::size_of::(); + let data_len = frame.data.len(); unsafe { - // TODO: handle VirtIO-net header - // For now, we just skip the header by dst.add(header_len) - core::ptr::copy_nonoverlapping(frame.data.as_ptr(), dst.add(header_len), len); + core::ptr::copy_nonoverlapping(header as *const _ as *const u8, dst.add(0), header_len); + core::ptr::copy_nonoverlapping(frame.data.as_ptr(), dst.add(header_len), data_len); } - header_len + len + header_len + data_len } - fn vio_start(&mut self, frame: &EtherFrameRef) { + fn vio_start(&mut self, frame: &EtherFrameRef) -> Result<(), VirtioDriverErr> { self.vio_tx_dequeue(); if let Some(slot) = self.virtio_enqueue_prep() { - let len = self.vio_encap(slot, frame); + let header = self.vio_tx_offload(frame)?; + let len = self.vio_encap(slot, frame, &header); self.virtio_enqueue_reserve(slot); self.virtio_enqueue(slot, len, true); self.virtio_enqueue_commit(slot); @@ -322,10 +384,13 @@ impl Virtq { if self.virtio_start_vq_intr() { self.vio_tx_dequeue(); } + + Ok(()) } } /// Packet header structure +#[derive(Default)] #[repr(C, packed)] struct VirtioNetHdr { flags: u8, @@ -337,13 +402,7 @@ struct VirtioNetHdr { num_buffers: u16, // only present if VIRTIO_NET_F_MRG_RXBUF is negotiated } -const _VIRTIO_NET_HDR_F_NEEDS_CSUM: u8 = 1; -const _VIRTIO_NET_HDR_F_DATA_VALID: u8 = 2; -const _VIRTIO_NET_HDR_GSO_NONE: u8 = 0; -const _VIRTIO_NET_HDR_GSO_TCPV4: u8 = 1; -const _VIRTIO_NET_HDR_GSO_UDP: u8 = 3; -const _VIRTIO_NET_HDR_GSO_TCPV6: u8 = 4; -const _VIRTIO_NET_HDR_GSO_ECN: u8 = 0x80; +const VIRTIO_NET_HDR_F_NEEDS_CSUM: u8 = 1; pub fn match_device(vendor: u16, id: u16) -> bool { vendor == pcie_id::VIRTIO_VENDOR_ID && id == VIRTIO_NET_ID @@ -456,6 +515,7 @@ impl VirtioNetInner { self.driver_features |= VIRTIO_NET_F_MAC; self.driver_features |= VIRTIO_NET_F_STATUS; self.driver_features |= VIRTIO_NET_F_SPEED_DUPLEX; + self.driver_features |= VIRTIO_NET_F_CSUM; self.virtio_pci_negotiate_features()?; @@ -468,6 +528,15 @@ impl VirtioNetInner { self.capabilities = NetCapabilities::empty(); self.flags = NetFlags::BROADCAST | NetFlags::SIMPLEX | NetFlags::MULTICAST; + if self.virtio_has_feature(VIRTIO_NET_F_CSUM) { + self.capabilities |= NetCapabilities::CSUM_UDPv4; + + // NOTE: we currently only support UDPv4 + // self.capabilities |= NetCapabilities::CSUM_TCPv4; + // self.capabilities |= NetCapabilities::CSUM_TCPv6; + // self.capabilities |= NetCapabilities::CSUM_UDPv6; + } + let num_queues = 1; // TODO: support multiple queues for i in 0..num_queues { let mut rx = self.virtio_alloc_vq(2 * i)?; @@ -955,7 +1024,7 @@ impl NetDevice for VirtioNet { let inner = self.inner.read(); let mut node = MCSNode::new(); let mut tx = inner.virtqueues[que_id].tx.lock(&mut node); - tx.vio_start(&data); + tx.vio_start(&data).or(Err(NetDevError::DeviceError))?; } let tx_vq_index = (que_id * 2 + 1) as u16;