diff --git a/vfio-ioctls/src/lib.rs b/vfio-ioctls/src/lib.rs index 8098896..8b06705 100644 --- a/vfio-ioctls/src/lib.rs +++ b/vfio-ioctls/src/lib.rs @@ -66,9 +66,9 @@ mod vfio_device; mod vfio_ioctls; pub use vfio_device::{ - VfioContainer, VfioDevice, VfioDeviceFd, VfioGroup, VfioIrq, VfioRegion, VfioRegionInfoCap, - VfioRegionInfoCapNvlink2Lnkspd, VfioRegionInfoCapNvlink2Ssatgt, VfioRegionInfoCapSparseMmap, - VfioRegionInfoCapType, VfioRegionSparseMmapArea, + VfioContainer, VfioDevice, VfioDeviceFd, VfioGroup, VfioIrq, VfioOps, VfioRegion, + VfioRegionInfoCap, VfioRegionInfoCapNvlink2Lnkspd, VfioRegionInfoCapNvlink2Ssatgt, + VfioRegionInfoCapSparseMmap, VfioRegionInfoCapType, VfioRegionSparseMmapArea, }; /// Error codes for VFIO operations. @@ -137,6 +137,8 @@ pub enum VfioError { GetHostAddress, #[error("invalid dma unmap size")] InvalidDmaUnmapSize, + #[error("failed to downcast VfioOps")] + DowncastVfioOps, } /// Specialized version of `Result` for VFIO subsystem. diff --git a/vfio-ioctls/src/vfio_device.rs b/vfio-ioctls/src/vfio_device.rs index 12221f2..298065c 100644 --- a/vfio-ioctls/src/vfio_device.rs +++ b/vfio-ioctls/src/vfio_device.rs @@ -3,15 +3,14 @@ // // SPDX-License-Identifier: Apache-2.0 OR BSD-3-Clause +use std::any::Any; use std::collections::HashMap; use std::ffi::CString; use std::fs::{File, OpenOptions}; use std::mem::{self, ManuallyDrop}; use std::os::unix::io::{AsRawFd, RawFd}; use std::os::unix::prelude::FileExt; -use std::path::Path; -#[cfg(not(test))] -use std::path::PathBuf; +use std::path::{Path, PathBuf}; use std::sync::{Arc, Mutex}; use byteorder::{ByteOrder, NativeEndian}; @@ -25,7 +24,7 @@ use crate::vfio_ioctls::*; use crate::{Result, VfioError}; #[cfg(all(feature = "kvm", not(test)))] use kvm_bindings::{ - kvm_device_attr, KVM_DEV_VFIO_GROUP, KVM_DEV_VFIO_GROUP_ADD, KVM_DEV_VFIO_GROUP_DEL, + kvm_device_attr, KVM_DEV_VFIO_FILE, KVM_DEV_VFIO_FILE_ADD, KVM_DEV_VFIO_FILE_DEL, }; #[cfg(all(feature = "kvm", not(test)))] use kvm_ioctls::DeviceFd as KvmDeviceFd; @@ -148,6 +147,90 @@ impl vfio_region_info_with_cap { region_with_cap } } +/// Trait to define common operations exposed to user-space drivers for +/// VFIO device wrappers that are either backed by a legacy VfioContainer or +/// a VFIO cdev device using iommufd. +pub trait VfioOps: Any + Send + Sync { + /// Map a region of user space memory (e.g. guest memory) into an IO + /// address space managed by IOMMU hardware to enable DMA for + /// associated VFIO devices + /// + /// # Parameters + /// * iova: IO virtual address to map the memory. + /// * size: size of the memory region. + /// * user_addr: user space address (e.g. host virtual address) for + /// the guest memory region to map. + fn vfio_dma_map(&self, _iova: u64, _size: u64, _user_addr: u64) -> Result<()> { + unimplemented!() + } + + /// Unmap a region of user space memory (e.g. guest memory) from an IO + /// address space managed by IOMMU hardware to disable DMA for + /// associated VFIO devices + /// + /// # Parameters + /// * iova: IO virtual address to unmap the memory. + /// * size: size of the memory region. + fn vfio_dma_unmap(&self, _iova: u64, _size: u64) -> Result<()> { + unimplemented!() + } + + /// Downcast to the underlying vfio wrapper type + fn as_any(&self) -> &dyn Any { + unimplemented!() + } +} + +struct VfioCommon { + #[allow(dead_code)] + device_fd: Option, +} + +impl VfioCommon { + #[cfg(all(any(feature = "kvm", feature = "mshv"), not(test)))] + fn device_set_fd(&self, dev_fd: RawFd, add: bool) -> Result<()> { + let dev_fd_ptr = &dev_fd as *const i32; + + if let Some(device_fd) = self.device_fd.as_ref() { + match &device_fd.0 { + #[cfg(feature = "kvm")] + DeviceFdInner::Kvm(fd) => { + let flag = if add { + KVM_DEV_VFIO_FILE_ADD + } else { + KVM_DEV_VFIO_FILE_DEL + }; + let dev_attr = kvm_device_attr { + flags: 0, + group: KVM_DEV_VFIO_FILE, + attr: u64::from(flag), + addr: dev_fd_ptr as u64, + }; + fd.set_device_attr(&dev_attr) + .map_err(|e| VfioError::SetDeviceAttr(Error::new(e.errno()))) + } + #[cfg(feature = "mshv")] + DeviceFdInner::Mshv(fd) => { + let flag = if add { + MSHV_DEV_VFIO_FILE_ADD + } else { + MSHV_DEV_VFIO_FILE_DEL + }; + let dev_attr = mshv_device_attr { + flags: 0, + group: MSHV_DEV_VFIO_FILE, + attr: u64::from(flag), + addr: dev_fd_ptr as u64, + }; + fd.set_device_attr(&dev_attr) + .map_err(|e| VfioError::SetDeviceAttr(Error::new(e.errno()))) + } + } + } else { + Ok(()) + } + } +} /// A safe wrapper over a VFIO container object. /// @@ -161,9 +244,9 @@ impl vfio_region_info_with_cap { /// address translation mapping tables. pub struct VfioContainer { pub(crate) container: File, - #[allow(dead_code)] - pub(crate) device_fd: Option, pub(crate) groups: Mutex>>, + #[allow(dead_code)] + common: VfioCommon, } impl VfioContainer { @@ -180,7 +263,7 @@ impl VfioContainer { let container = VfioContainer { container, - device_fd, + common: VfioCommon { device_fd }, groups: Mutex::new(HashMap::new()), }; container.check_api_version()?; @@ -256,10 +339,9 @@ impl VfioContainer { // Clean up the group when the last user releases reference to the group, three reference // count for: - // - one reference held by the last device object // - one reference cloned in VfioDevice.drop() and passed into here // - one reference held by the groups hashmap - if Arc::strong_count(&group) == 3 { + if Arc::strong_count(&group) == 2 { #[cfg(any(feature = "kvm", feature = "mshv"))] match self.device_del_group(&group) { Ok(_) => {} @@ -276,7 +358,9 @@ impl VfioContainer { } } - /// Map a region of guest memory regions into the vfio container's iommu table. + /// Map a region of user space memory (e.g. guest memory) into an IO + /// address space managed by IOMMU hardware to enable DMA for + /// associated VFIO devices /// /// # Parameters /// * iova: IO virtual address to mapping the memory. @@ -294,10 +378,12 @@ impl VfioContainer { vfio_syscall::map_dma(self, &dma_map) } - /// Unmap a region of guest memory regions into the vfio container's iommu table. + /// Unmap a region of user space memory (e.g. guest memory) from an IO + /// address space managed by IOMMU hardware to disable DMA for + /// associated VFIO devices /// /// # Parameters - /// * iova: IO virtual address to mapping the memory. + /// * iova: IO virtual address to unmap the memory. /// * size: size of the memory region. pub fn vfio_dma_unmap(&self, iova: u64, size: u64) -> Result<()> { let mut dma_unmap = vfio_iommu_type1_dma_unmap { @@ -346,50 +432,6 @@ impl VfioContainer { }) } - #[cfg(all(any(feature = "kvm", feature = "mshv"), not(test)))] - fn device_set_group(&self, group: &VfioGroup, add: bool) -> Result<()> { - let group_fd_ptr = &group.as_raw_fd() as *const i32; - - if let Some(device_fd) = self.device_fd.as_ref() { - match &device_fd.0 { - #[cfg(feature = "kvm")] - DeviceFdInner::Kvm(fd) => { - let flag = if add { - KVM_DEV_VFIO_GROUP_ADD - } else { - KVM_DEV_VFIO_GROUP_DEL - }; - let dev_attr = kvm_device_attr { - flags: 0, - group: KVM_DEV_VFIO_GROUP, - attr: u64::from(flag), - addr: group_fd_ptr as u64, - }; - fd.set_device_attr(&dev_attr) - .map_err(|e| VfioError::SetDeviceAttr(Error::new(e.errno()))) - } - #[cfg(feature = "mshv")] - DeviceFdInner::Mshv(fd) => { - let flag = if add { - MSHV_DEV_VFIO_FILE_ADD - } else { - MSHV_DEV_VFIO_FILE_DEL - }; - let dev_attr = mshv_device_attr { - flags: 0, - group: MSHV_DEV_VFIO_FILE, - attr: u64::from(flag), - addr: group_fd_ptr as u64, - }; - fd.set_device_attr(&dev_attr) - .map_err(|e| VfioError::SetDeviceAttr(Error::new(e.errno()))) - } - } - } else { - Ok(()) - } - } - /// Add a device to a VFIO group /// /// The VFIO device fd should have been set. @@ -398,7 +440,7 @@ impl VfioContainer { /// * group: target VFIO group #[cfg(all(any(feature = "kvm", feature = "mshv"), not(test)))] fn device_add_group(&self, group: &VfioGroup) -> Result<()> { - self.device_set_group(group, true) + self.common.device_set_fd(group.as_raw_fd(), true) } /// Delete a device from a VFIO group @@ -409,7 +451,7 @@ impl VfioContainer { /// * group: target VFIO group #[cfg(all(any(feature = "kvm", feature = "mshv"), not(test)))] fn device_del_group(&self, group: &VfioGroup) -> Result<()> { - self.device_set_group(group, false) + self.common.device_set_fd(group.as_raw_fd(), false) } #[cfg(test)] @@ -429,6 +471,20 @@ impl AsRawFd for VfioContainer { } } +impl VfioOps for VfioContainer { + fn vfio_dma_map(&self, iova: u64, size: u64, user_addr: u64) -> Result<()> { + self.vfio_dma_map(iova, size, user_addr) + } + + fn vfio_dma_unmap(&self, iova: u64, size: u64) -> Result<()> { + self.vfio_dma_unmap(iova, size) + } + + fn as_any(&self) -> &dyn Any { + self + } +} + /// A safe wrapper over a VFIO group object. /// /// The Linux VFIO frameworks supports multiple devices per group, and multiple groups per @@ -473,50 +529,12 @@ impl VfioGroup { self.id } - #[inline] - /// Get device type from device_info flags. - /// - /// # Parameters - /// * `flags`: flags field in device_info structure. - fn get_device_type(flags: &u32) -> u32 { - // There may be more types of device here later according to vfio_bindings. - let device_type: u32 = VFIO_DEVICE_FLAGS_PCI - | VFIO_DEVICE_FLAGS_PLATFORM - | VFIO_DEVICE_FLAGS_AMBA - | VFIO_DEVICE_FLAGS_CCW - | VFIO_DEVICE_FLAGS_AP; - - flags & device_type - } - fn get_device(&self, name: &Path) -> Result { let uuid_osstr = name.file_name().ok_or(VfioError::InvalidPath)?; let uuid_str = uuid_osstr.to_str().ok_or(VfioError::InvalidPath)?; let path: CString = CString::new(uuid_str.as_bytes()).expect("CString::new() failed"); let device = vfio_syscall::get_group_device_fd(self, &path)?; - - let mut dev_info = vfio_device_info { - argsz: mem::size_of::() as u32, - flags: 0, - num_regions: 0, - num_irqs: 0, - cap_offset: 0, - pad: 0, - }; - vfio_syscall::get_device_info(&device, &mut dev_info)?; - match VfioGroup::get_device_type(&dev_info.flags) { - VFIO_DEVICE_FLAGS_PLATFORM => {} - VFIO_DEVICE_FLAGS_PCI => { - if dev_info.num_regions < VFIO_PCI_CONFIG_REGION_INDEX + 1 - || dev_info.num_irqs < VFIO_PCI_MSIX_IRQ_INDEX + 1 - { - return Err(VfioError::VfioDeviceGetInfoPCI); - } - } - _ => { - return Err(VfioError::VfioDeviceGetInfoOther); - } - } + let dev_info = VfioDeviceInfo::get_device_info(&device)?; Ok(VfioDeviceInfo::new(device, &dev_info)) } @@ -610,6 +628,49 @@ pub(crate) struct VfioDeviceInfo { } impl VfioDeviceInfo { + #[inline] + /// Get device type from device_info flags. + /// + /// # Parameters + /// * `flags`: flags field in device_info structure. + fn get_device_type(flags: &u32) -> u32 { + // There may be more types of device here later according to vfio_bindings. + let device_type: u32 = VFIO_DEVICE_FLAGS_PCI + | VFIO_DEVICE_FLAGS_PLATFORM + | VFIO_DEVICE_FLAGS_AMBA + | VFIO_DEVICE_FLAGS_CCW + | VFIO_DEVICE_FLAGS_AP; + + flags & device_type + } + + fn get_device_info(device: &File) -> Result { + let mut dev_info = vfio_device_info { + argsz: mem::size_of::() as u32, + flags: 0, + num_regions: 0, + num_irqs: 0, + cap_offset: 0, + pad: 0, + }; + vfio_syscall::get_device_info(device, &mut dev_info)?; + match VfioDeviceInfo::get_device_type(&dev_info.flags) { + VFIO_DEVICE_FLAGS_PLATFORM => {} + VFIO_DEVICE_FLAGS_PCI => { + if dev_info.num_regions < VFIO_PCI_CONFIG_REGION_INDEX + 1 + || dev_info.num_irqs < VFIO_PCI_MSIX_IRQ_INDEX + 1 + { + return Err(VfioError::VfioDeviceGetInfoPCI); + } + } + _ => { + return Err(VfioError::VfioDeviceGetInfoOther); + } + } + + Ok(dev_info) + } + fn new(device: File, dev_info: &vfio_device_info) -> Self { VfioDeviceInfo { device, @@ -827,8 +888,8 @@ pub struct VfioDevice { pub(crate) flags: u32, pub(crate) regions: Vec, pub(crate) irqs: HashMap, - pub(crate) group: Arc, - pub(crate) container: Arc, + pub(crate) sysfspath: PathBuf, + pub(crate) vfio_ops: Arc, } impl VfioDevice { @@ -846,11 +907,17 @@ impl VfioDevice { /// /// # Parameters /// * `sysfspath`: specify the vfio device path in sys file system. - /// * `container`: the new VFIO device object will bind to this container object. - pub fn new(sysfspath: &Path, container: Arc) -> Result { - let group_id = Self::get_group_id_from_path(sysfspath)?; - let group = container.get_group(group_id)?; - let device_info = group.get_device(sysfspath)?; + /// * `vfio_ops`: the vfio device wrapper object that the new VFIO device object will bind to. + pub fn new(sysfspath: &Path, vfio_ops: Arc) -> Result { + let device_info = + if let Some(vfio_container) = vfio_ops.as_any().downcast_ref::() { + let group_id = Self::get_group_id_from_path(sysfspath)?; + let group = vfio_container.get_group(group_id)?; + group.get_device(sysfspath)? + } else { + return Err(VfioError::DowncastVfioOps); + }; + let regions = device_info.get_regions()?; let irqs = device_info.get_irqs()?; @@ -859,8 +926,8 @@ impl VfioDevice { flags: device_info.flags, regions, irqs, - group, - container, + sysfspath: sysfspath.to_path_buf(), + vfio_ops, }) } @@ -1214,12 +1281,16 @@ impl Drop for VfioDevice { // ManuallyDrop is needed here because we need to ensure that VfioDevice::device is closed // before dropping VfioDevice::group, otherwise it will cause EBUSY when putting the // group object. + if let Some(container) = self.vfio_ops.as_any().downcast_ref::() { + // SAFETY: we own the File object. + unsafe { + ManuallyDrop::drop(&mut self.device); + } - // SAFETY: we own the File object. - unsafe { - ManuallyDrop::drop(&mut self.device); + let group_id = Self::get_group_id_from_path(&self.sysfspath).unwrap(); + let group = container.get_group(group_id).unwrap(); + container.put_group(group); } - self.container.put_group(self.group.clone()); } } @@ -1328,7 +1399,7 @@ mod tests { VfioContainer { container, - device_fd: None, + common: VfioCommon { device_fd: None }, groups: Mutex::new(HashMap::new()), } } @@ -1359,7 +1430,8 @@ mod tests { container.put_group(group4); assert_eq!(Arc::strong_count(&group), 3); container.put_group(group3); - assert_eq!(Arc::strong_count(&group), 1); + assert_eq!(Arc::strong_count(&group), 2); + container.put_group(group); container.vfio_dma_map(0x1000, 0x1000, 0x8000).unwrap(); container.vfio_dma_map(0x2000, 0x2000, 0x8000).unwrap_err(); @@ -1527,18 +1599,18 @@ mod tests { #[test] fn test_get_device_type() { let flags: u32 = VFIO_DEVICE_FLAGS_PCI; - assert_eq!(flags, VfioGroup::get_device_type(&flags)); + assert_eq!(flags, VfioDeviceInfo::get_device_type(&flags)); let flags: u32 = VFIO_DEVICE_FLAGS_PLATFORM; - assert_eq!(flags, VfioGroup::get_device_type(&flags)); + assert_eq!(flags, VfioDeviceInfo::get_device_type(&flags)); let flags: u32 = VFIO_DEVICE_FLAGS_AMBA; - assert_eq!(flags, VfioGroup::get_device_type(&flags)); + assert_eq!(flags, VfioDeviceInfo::get_device_type(&flags)); let flags: u32 = VFIO_DEVICE_FLAGS_CCW; - assert_eq!(flags, VfioGroup::get_device_type(&flags)); + assert_eq!(flags, VfioDeviceInfo::get_device_type(&flags)); let flags: u32 = VFIO_DEVICE_FLAGS_AP; - assert_eq!(flags, VfioGroup::get_device_type(&flags)); + assert_eq!(flags, VfioDeviceInfo::get_device_type(&flags)); } }