diff --git a/pkg/nvlib/device/api.go b/pkg/nvlib/device/api.go index c2a6517..1b823e3 100644 --- a/pkg/nvlib/device/api.go +++ b/pkg/nvlib/device/api.go @@ -27,8 +27,10 @@ type Interface interface { GetMigDevices() ([]MigDevice, error) GetMigProfiles() ([]MigProfile, error) NewDevice(d nvml.Device) (Device, error) + NewDeviceByIdentifier(Identifier) (Device, error) NewDeviceByUUID(uuid string) (Device, error) NewMigDevice(d nvml.Device) (MigDevice, error) + NewMigDeviceByIdentifier(Identifier) (MigDevice, error) NewMigDeviceByUUID(uuid string) (MigDevice, error) NewMigProfile(giProfileID, ciProfileID, ciEngProfileID int, migMemorySizeMB, deviceMemorySizeBytes uint64) (MigProfile, error) ParseMigProfile(profile string) (MigProfile, error) diff --git a/pkg/nvlib/device/device.go b/pkg/nvlib/device/device.go index 5e1510c..d866e21 100644 --- a/pkg/nvlib/device/device.go +++ b/pkg/nvlib/device/device.go @@ -18,6 +18,7 @@ package device import ( "fmt" + "strconv" "github.com/NVIDIA/go-nvml/pkg/nvml" ) @@ -49,6 +50,31 @@ func (d *devicelib) NewDevice(dev nvml.Device) (Device, error) { return d.newDevice(dev) } +// NewDeviceByIdentifier builds a new device from a device identifier. +func (d *devicelib) NewDeviceByIdentifier(id Identifier) (Device, error) { + switch { + case id.IsGpuUUID(): + return d.NewDeviceByUUID(string(id)) + case id.IsGpuIndex(): + idx, err := strconv.Atoi(string(id)) + if err != nil { + return nil, fmt.Errorf("failed to convert device index to an int: %w", err) + } + return d.NewDeviceByIndex(idx) + default: + return nil, fmt.Errorf("invalid device identifier: %v", id) + } +} + +// NewDeviceByIndex builds a new Device for the specified index. +func (d *devicelib) NewDeviceByIndex(index int) (Device, error) { + dev, ret := d.nvmllib.DeviceGetHandleByIndex(index) + if ret != nvml.SUCCESS { + return nil, fmt.Errorf("error getting device handle for index '%v': %v", index, ret) + } + return d.newDevice(dev) +} + // NewDeviceByUUID builds a new Device from a UUID. func (d *devicelib) NewDeviceByUUID(uuid string) (Device, error) { dev, ret := d.nvmllib.DeviceGetHandleByUUID(uuid) diff --git a/pkg/nvlib/device/mig_device.go b/pkg/nvlib/device/mig_device.go index 7145a06..cccc94c 100644 --- a/pkg/nvlib/device/mig_device.go +++ b/pkg/nvlib/device/mig_device.go @@ -18,6 +18,8 @@ package device import ( "fmt" + "strconv" + "strings" "github.com/NVIDIA/go-nvml/pkg/nvml" ) @@ -45,6 +47,42 @@ func (d *devicelib) NewMigDevice(handle nvml.Device) (MigDevice, error) { if !isMig { return nil, fmt.Errorf("not a MIG device") } + return d.newMigDevice(handle) +} + +// NewMigDeviceByIdentifier builds a new MigDevice for the specified identifier. +// If the identifier is not a valid MIG identifier, an error is raised. +func (d *devicelib) NewMigDeviceByIdentifier(id Identifier) (MigDevice, error) { + switch { + case id.IsMigUUID(): + return d.NewMigDeviceByUUID(string(id)) + case id.IsMigIndex(): + split := strings.SplitN(string(id), ":", 2) + gpuIdx, err := strconv.Atoi(split[0]) + if err != nil { + return nil, fmt.Errorf("failed to convert device index to an int: %w", err) + } + migIdx, err := strconv.Atoi(split[1]) + if err != nil { + return nil, fmt.Errorf("failed to convert device index to an int: %w", err) + } + parent, err := d.NewDeviceByIndex(gpuIdx) + if err != nil { + return nil, fmt.Errorf("failed to get parent device handle: %w", err) + } + migDevice, ret := parent.GetMigDeviceHandleByIndex(migIdx) + if ret != nvml.SUCCESS { + return nil, fmt.Errorf("failed to get mig device by index: %w", ret) + } + return d.newMigDevice(migDevice) + default: + return nil, fmt.Errorf("invalid MIG device identifier: %v", id) + } +} + +// newMigDevice constructs a new MigDevice for the supplied handle. +// The handle is not checked for validity. +func (d *devicelib) newMigDevice(handle nvml.Device) (MigDevice, error) { return &migdevice{handle, d, nil}, nil }