diff --git a/pkg/nvmdev/nvmdev.go b/pkg/nvmdev/nvmdev.go index 24ec228..37acac0 100644 --- a/pkg/nvmdev/nvmdev.go +++ b/pkg/nvmdev/nvmdev.go @@ -56,6 +56,7 @@ type ParentDevice struct { type Device struct { Path string UUID string + MDEVName string MDEVType string Driver string IommuGroup int @@ -141,6 +142,11 @@ func NewDevice(root string, uuid string) (*Device, error) { return nil, nil } + mdevName, err := m.name() + if err != nil { + return nil, fmt.Errorf("error geting mdev name: %v", err) + } + mdevType, err := m.Type() if err != nil { return nil, fmt.Errorf("error getting mdev type: %v", err) @@ -159,6 +165,7 @@ func NewDevice(root string, uuid string) (*Device, error) { device := Device{ Path: path, UUID: uuid, + MDEVName: mdevName, MDEVType: mdevType, Driver: driver, IommuGroup: iommuGroup, @@ -198,24 +205,34 @@ func (m mdev) parentDevicePath() string { return path.Dir(string(m)) } -func (m mdev) Type() (string, error) { +func (m mdev) name() (string, error) { mdevTypeDir, err := m.resolve("mdev_type") if err != nil { return "", err } - mdevType, err := os.ReadFile(path.Join(mdevTypeDir, "name")) + mdevName, err := os.ReadFile(path.Join(mdevTypeDir, "name")) if err != nil { return "", fmt.Errorf("unable to read mdev_type name for mdev %s: %v", m, err) } - // file in the format: [NVIDIA|GRID] - mdevTypeStr := strings.TrimSpace(string(mdevType)) - mdevTypeSplit := strings.SplitN(mdevTypeStr, " ", 2) - if len(mdevTypeSplit) != 2 { - return "", fmt.Errorf("unable to parse mdev_type name %s for mdev %s", mdevTypeStr, m) + mdevNameStr := strings.TrimSpace(string(mdevName)) + + return mdevNameStr, nil +} + +func (m mdev) Type() (string, error) { + mdevName, err := m.name() + if err != nil { + return "", fmt.Errorf("error getting the mdev_type name: %v", err) + } + + // mdevName is in the format: [NVIDIA|GRID] + mdevNameSplit := strings.SplitN(mdevName, " ", 2) + if len(mdevNameSplit) != 2 { + return "", fmt.Errorf("unable to parse mdev_type name '%s' for mdev %s", mdevName, m) } - return mdevTypeSplit[1], nil + return mdevNameSplit[1], nil } func (m mdev) driver() (string, error) { diff --git a/pkg/nvmdev/nvmdev_test.go b/pkg/nvmdev/nvmdev_test.go index 517c895..de0ca6a 100644 --- a/pkg/nvmdev/nvmdev_test.go +++ b/pkg/nvmdev/nvmdev_test.go @@ -17,8 +17,9 @@ package nvmdev import ( - "github.com/stretchr/testify/require" "testing" + + "github.com/stretchr/testify/require" ) func TestNvmdev(t *testing.T) { @@ -54,7 +55,8 @@ func TestNvmdev(t *testing.T) { mdevA100 := mdevs[0] - require.Equal(t, "A100-4C", mdevA100.MDEVType, "Wrong value for mdev_type") + require.Equal(t, "NVIDIA A100-4C", mdevA100.MDEVName, "Wrong value for mdev name") + require.Equal(t, "A100-4C", mdevA100.MDEVType, "Wrong value for mdev type") require.Equal(t, "vfio_mdev", mdevA100.Driver, "Wrong driver detected for mdev device") require.Equal(t, 200, mdevA100.IommuGroup, "Wrong value for iommu_group")