Skip to content

Commit 4eb4ca6

Browse files
committed
Handle multiple GPUs in CDI spec generation from CSV
This change allows CDI specs to be generated for multiple devices when using CSV mode. This can be used in cases where a Tegra-based system consists of an iGPU and dGPU. This behavior can be opted out of using the disable-multiple-csv-devices feature flag. This can be specified by adding the --feaure-flags=disable-multiple-csv-devices command line option to the nvidia-ctk cdi generate command or to the automatic CDI spec generation by adding NVIDIA_CTK_CDI_GENERATE_FEATURE_FLAGS=disable-multiple-csv-devices to the /etc/nvidia-container-toolkit/nvidia-cdi-refresh.env file. Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent 39644f9 commit 4eb4ca6

File tree

5 files changed

+282
-39
lines changed

5 files changed

+282
-39
lines changed

pkg/nvcdi/api.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,8 @@ const (
8888
// FeatureEnableCoherentAnnotations enables the addition of annotations
8989
// coherent or non-coherent devices.
9090
FeatureEnableCoherentAnnotations = FeatureFlag("enable-coherent-annotations")
91+
92+
// FeatureDisableMultipleCSVDevices disables the handling of multiple devices
93+
// in CSV mode.
94+
FeatureDisableMultipleCSVDevices = FeatureFlag("disable-multiple-csv-devices")
9195
)

pkg/nvcdi/common-nvml.go

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,16 +25,7 @@ import (
2525
// newCommonNVMLDiscoverer returns a discoverer for entities that are not associated with a specific CDI device.
2626
// This includes driver libraries and meta devices, for example.
2727
func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) {
28-
metaDevices := discover.NewCharDeviceDiscoverer(
29-
l.logger,
30-
l.devRoot,
31-
[]string{
32-
"/dev/nvidia-modeset",
33-
"/dev/nvidia-uvm-tools",
34-
"/dev/nvidia-uvm",
35-
"/dev/nvidiactl",
36-
},
37-
)
28+
metaDevices := l.controlDeviceNodeDiscoverer()
3829

3930
graphicsMounts, err := discover.NewGraphicsMountsDiscoverer(l.logger, l.driver, l.hookCreator)
4031
if err != nil {
@@ -54,3 +45,16 @@ func (l *nvmllib) newCommonNVMLDiscoverer() (discover.Discover, error) {
5445

5546
return d, nil
5647
}
48+
49+
func (l *nvmllib) controlDeviceNodeDiscoverer() discover.Discover {
50+
return discover.NewCharDeviceDiscoverer(
51+
l.logger,
52+
l.devRoot,
53+
[]string{
54+
"/dev/nvidia-modeset",
55+
"/dev/nvidia-uvm-tools",
56+
"/dev/nvidia-uvm",
57+
"/dev/nvidiactl",
58+
},
59+
)
60+
}

pkg/nvcdi/full-gpu-nvml.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@ type fullGPUDeviceSpecGenerator struct {
3737
uuid string
3838
index int
3939

40-
featureFlags map[FeatureFlag]bool
40+
featureFlags map[FeatureFlag]bool
41+
additionalDiscoverers []discover.Discover
4142
}
4243

4344
var _ DeviceSpecGenerator = (*fullGPUDeviceSpecGenerator)(nil)
@@ -145,7 +146,6 @@ func (l *fullGPUDeviceSpecGenerator) getDeviceEdits() (*cdi.ContainerEdits, erro
145146
if err != nil {
146147
return nil, fmt.Errorf("failed to create device discoverer: %v", err)
147148
}
148-
149149
editsForDevice, err := edits.FromDiscoverer(deviceDiscoverer)
150150
if err != nil {
151151
return nil, fmt.Errorf("failed to create container edits for device: %v", err)
@@ -177,10 +177,18 @@ func (l *fullGPUDeviceSpecGenerator) newFullGPUDiscoverer(d device.Device) (disc
177177
deviceNodes,
178178
)
179179

180-
dd := discover.Merge(
180+
var discoverers []discover.Discover
181+
182+
discoverers = append(discoverers,
181183
deviceNodes,
182184
deviceFolderPermissionHooks,
183185
)
184186

187+
discoverers = append(discoverers, l.additionalDiscoverers...)
188+
189+
dd := discover.Merge(
190+
discoverers...,
191+
)
192+
185193
return dd, nil
186194
}

pkg/nvcdi/lib-csv.go

Lines changed: 253 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,49 @@ package nvcdi
1818

1919
import (
2020
"fmt"
21+
"slices"
22+
"strconv"
2123

2224
"tags.cncf.io/container-device-interface/pkg/cdi"
2325
"tags.cncf.io/container-device-interface/specs-go"
2426

27+
"github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
28+
"github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
29+
"github.com/NVIDIA/go-nvml/pkg/nvml"
30+
"github.com/google/uuid"
31+
2532
"github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2633
"github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2734
"github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra"
2835
)
2936

3037
type csvlib nvcdilib
3138

39+
type mixedcsvlib nvcdilib
40+
3241
var _ deviceSpecGeneratorFactory = (*csvlib)(nil)
3342

43+
// DeviceSpecGenerators creates a set of generators for the specified set of
44+
// devices.
45+
// If NVML is not available or the disable-multiple-csv-devices feature flag is
46+
// enabled, a single device is assumed.
3447
func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
48+
if l.featureFlags[FeatureDisableMultipleCSVDevices] {
49+
return l.purecsvDeviceSpecGenerators(ids...)
50+
}
51+
hasNVML, _ := l.infolib.HasNvml()
52+
if !hasNVML {
53+
return l.purecsvDeviceSpecGenerators(ids...)
54+
}
55+
mixed, err := l.mixedDeviceSpecGenerators(ids...)
56+
if err != nil {
57+
l.logger.Warningf("Failed to create mixed CSV spec generator; falling back to pure CSV implementation: %v", err)
58+
return l.purecsvDeviceSpecGenerators(ids...)
59+
}
60+
return mixed, nil
61+
}
62+
63+
func (l *csvlib) purecsvDeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
3564
for _, id := range ids {
3665
switch id {
3766
case "all":
@@ -40,35 +69,42 @@ func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error
4069
return nil, fmt.Errorf("unsupported device id: %v", id)
4170
}
4271
}
72+
g := &csvDeviceGenerator{
73+
csvlib: l,
74+
index: 0,
75+
uuid: "",
76+
}
77+
return g, nil
78+
}
79+
80+
func (l *csvlib) mixedDeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
81+
return (*mixedcsvlib)(l).DeviceSpecGenerators(ids...)
82+
}
4383

44-
return l, nil
84+
// A csvDeviceGenerator generates CDI specs for a device based on a set of
85+
// platform-specific CSV files.
86+
type csvDeviceGenerator struct {
87+
*csvlib
88+
index int
89+
uuid string
90+
}
91+
92+
func (l *csvDeviceGenerator) GetUUID() (string, error) {
93+
return l.uuid, nil
4594
}
4695

4796
// GetDeviceSpecs returns the CDI device specs for a single device.
48-
func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
49-
d, err := tegra.New(
50-
tegra.WithLogger(l.logger),
51-
tegra.WithDriverRoot(l.driverRoot),
52-
tegra.WithDevRoot(l.devRoot),
53-
tegra.WithHookCreator(l.hookCreator),
54-
tegra.WithLdconfigPath(l.ldconfigPath),
55-
tegra.WithLibrarySearchPaths(l.librarySearchPaths...),
56-
tegra.WithMountSpecs(
57-
tegra.Transform(
58-
tegra.MountSpecsFromCSVFiles(l.logger, l.csvFiles...),
59-
tegra.IgnoreSymlinkMountSpecsByPattern(l.csvIgnorePatterns...),
60-
),
61-
),
62-
)
97+
func (l *csvDeviceGenerator) GetDeviceSpecs() ([]specs.Device, error) {
98+
deviceNodeDiscoverer, err := l.deviceNodeDiscoverer()
6399
if err != nil {
64-
return nil, fmt.Errorf("failed to create discoverer for CSV files: %v", err)
100+
return nil, fmt.Errorf("failed to create discoverer for device nodes from CSV files: %w", err)
65101
}
66-
e, err := edits.FromDiscoverer(d)
102+
e, err := edits.FromDiscoverer(deviceNodeDiscoverer)
67103
if err != nil {
68104
return nil, fmt.Errorf("failed to create container edits for CSV files: %v", err)
69105
}
70106

71-
names, err := l.deviceNamers.GetDeviceNames(0, uuidIgnored{})
107+
names, err := l.deviceNamers.GetDeviceNames(l.index, l)
72108
if err != nil {
73109
return nil, fmt.Errorf("failed to get device name: %v", err)
74110
}
@@ -84,7 +120,204 @@ func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
84120
return deviceSpecs, nil
85121
}
86122

123+
// deviceNodeDiscoverer creates a discoverer for the device nodes associated
124+
// with the specified device.
125+
// The CSV mount specs are used as the source for which device nodes are
126+
// required with the following additions:
127+
//
128+
// - Any regular device nodes (i.e. /dev/nvidia[0-9]+) are removed from the
129+
// input set.
130+
// - The device node (i.e. /dev/nvidia{{ .index }}) associated with this
131+
// particular device is added to the set of device nodes to be discovered.
132+
func (l *csvDeviceGenerator) deviceNodeDiscoverer() (discover.Discover, error) {
133+
mountSpecs := tegra.Transform(
134+
tegra.Transform(
135+
tegra.MountSpecsFromCSVFiles(l.logger, l.csvFiles...),
136+
// We remove non-device nodes.
137+
tegra.OnlyDeviceNodes(),
138+
),
139+
// We remove the regular (nvidia[0-9]+) device nodes.
140+
tegra.WithoutRegularDeviceNodes(),
141+
)
142+
return tegra.New(
143+
tegra.WithLogger(l.logger),
144+
tegra.WithDriverRoot(l.driverRoot),
145+
tegra.WithDevRoot(l.devRoot),
146+
tegra.WithHookCreator(l.hookCreator),
147+
tegra.WithLdconfigPath(l.ldconfigPath),
148+
tegra.WithLibrarySearchPaths(l.librarySearchPaths...),
149+
tegra.WithMountSpecs(
150+
mountSpecs,
151+
// We add the specific device node for this device.
152+
tegra.DeviceNodes(fmt.Sprintf("/dev/nvidia%d", l.index)),
153+
),
154+
)
155+
}
156+
87157
// GetCommonEdits generates a CDI specification that can be used for ANY devices
158+
// These explicitly do not include any device nodes.
88159
func (l *csvlib) GetCommonEdits() (*cdi.ContainerEdits, error) {
89-
return edits.FromDiscoverer(discover.None{})
160+
mountSpecs := tegra.Transform(
161+
tegra.Transform(
162+
tegra.MountSpecsFromCSVFiles(l.logger, l.csvFiles...),
163+
tegra.WithoutDeviceNodes(),
164+
),
165+
tegra.IgnoreSymlinkMountSpecsByPattern(l.csvIgnorePatterns...),
166+
)
167+
driverDiscoverer, err := tegra.New(
168+
tegra.WithLogger(l.logger),
169+
tegra.WithDriverRoot(l.driverRoot),
170+
tegra.WithDevRoot(l.devRoot),
171+
tegra.WithHookCreator(l.hookCreator),
172+
tegra.WithLdconfigPath(l.ldconfigPath),
173+
tegra.WithLibrarySearchPaths(l.librarySearchPaths...),
174+
tegra.WithMountSpecs(mountSpecs),
175+
)
176+
if err != nil {
177+
return nil, fmt.Errorf("failed to create driver discoverer from CSV files: %w", err)
178+
}
179+
return edits.FromDiscoverer(driverDiscoverer)
180+
}
181+
182+
func (l *mixedcsvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error) {
183+
asNvmlLib := (*nvmllib)(l)
184+
err := asNvmlLib.init()
185+
if err != nil {
186+
return nil, fmt.Errorf("failed to initialize nvml: %w", err)
187+
}
188+
defer asNvmlLib.tryShutdown()
189+
190+
if slices.Contains(ids, "all") {
191+
ids, err = l.getAllDeviceIndices()
192+
if err != nil {
193+
return nil, fmt.Errorf("failed to get device indices: %w", err)
194+
}
195+
}
196+
197+
var DeviceSpecGenerators DeviceSpecGenerators
198+
for _, id := range ids {
199+
generator, err := l.deviceSpecGeneratorForId(device.Identifier(id))
200+
if err != nil {
201+
return nil, fmt.Errorf("failed to create device spec generator for device %q: %w", id, err)
202+
}
203+
DeviceSpecGenerators = append(DeviceSpecGenerators, generator)
204+
}
205+
206+
return DeviceSpecGenerators, nil
207+
}
208+
209+
func (l *mixedcsvlib) getAllDeviceIndices() ([]string, error) {
210+
numDevices, ret := l.nvmllib.DeviceGetCount()
211+
if ret != nvml.SUCCESS {
212+
return nil, fmt.Errorf("faled to get device count: %v", ret)
213+
}
214+
215+
var allIndices []string
216+
for index := range numDevices {
217+
allIndices = append(allIndices, fmt.Sprintf("%d", index))
218+
}
219+
return allIndices, nil
220+
}
221+
222+
func (l *mixedcsvlib) deviceSpecGeneratorForId(id device.Identifier) (DeviceSpecGenerator, error) {
223+
switch {
224+
case id.IsGpuUUID(), isIntegratedGPUID(id):
225+
uuid := string(id)
226+
device, ret := l.nvmllib.DeviceGetHandleByUUID(uuid)
227+
if ret != nvml.SUCCESS {
228+
return nil, fmt.Errorf("failed to get device handle from UUID %q: %v", uuid, ret)
229+
}
230+
index, ret := device.GetIndex()
231+
if ret != nvml.SUCCESS {
232+
return nil, fmt.Errorf("failed to get device index: %v", ret)
233+
}
234+
return l.csvDeviceSpecGenerator(index, uuid, device)
235+
case id.IsGpuIndex():
236+
index, err := strconv.Atoi(string(id))
237+
if err != nil {
238+
return nil, fmt.Errorf("failed to convert device index to an int: %w", err)
239+
}
240+
device, ret := l.nvmllib.DeviceGetHandleByIndex(index)
241+
if ret != nvml.SUCCESS {
242+
return nil, fmt.Errorf("failed to get device handle from index: %v", ret)
243+
}
244+
uuid, ret := device.GetUUID()
245+
if ret != nvml.SUCCESS {
246+
return nil, fmt.Errorf("failed to get UUID: %v", ret)
247+
}
248+
return l.csvDeviceSpecGenerator(index, uuid, device)
249+
case id.IsMigUUID():
250+
fallthrough
251+
case id.IsMigIndex():
252+
return nil, fmt.Errorf("generating a CDI spec for MIG id %q is not supported in CSV mode", id)
253+
}
254+
return nil, fmt.Errorf("identifier is not a valid UUID or index: %q", id)
255+
}
256+
257+
func (l *mixedcsvlib) csvDeviceSpecGenerator(index int, uuid string, device nvml.Device) (DeviceSpecGenerator, error) {
258+
isIntegrated, err := isIntegratedGPU(device)
259+
if err != nil {
260+
return nil, fmt.Errorf("is-integrated check failed for device (index=%v,uuid=%v)", index, uuid)
261+
}
262+
263+
g := &csvDeviceGenerator{
264+
csvlib: (*csvlib)(l),
265+
index: index,
266+
uuid: uuid,
267+
}
268+
269+
if !isIntegrated {
270+
csvDeviceNodeDiscoverer, err := g.deviceNodeDiscoverer()
271+
if err != nil {
272+
return nil, fmt.Errorf("failed to create discoverer for devices nodes: %w", err)
273+
}
274+
275+
// If this is not an integrated GPU, we also create a spec generator for
276+
// the full GPU.
277+
dgpu := (*nvmllib)(l).withInit(&fullGPUDeviceSpecGenerator{
278+
nvmllib: (*nvmllib)(l),
279+
uuid: uuid,
280+
index: index,
281+
// For the CSV case, we include the control device nodes at a
282+
// device level.
283+
additionalDiscoverers: []discover.Discover{
284+
(*nvmllib)(l).controlDeviceNodeDiscoverer(),
285+
csvDeviceNodeDiscoverer,
286+
},
287+
featureFlags: l.featureFlags,
288+
})
289+
return dgpu, nil
290+
}
291+
292+
return g, nil
293+
}
294+
295+
func isIntegratedGPUID(id device.Identifier) bool {
296+
_, err := uuid.Parse(string(id))
297+
return err == nil
298+
}
299+
300+
// isIntegratedGPU checks whether the specified device is an integrated GPU.
301+
// As a proxy we check the PCI Bus if for thes
302+
// TODO: This should be replaced by an explicit NVML call once available.
303+
func isIntegratedGPU(d nvml.Device) (bool, error) {
304+
pciInfo, ret := d.GetPciInfo()
305+
if ret == nvml.ERROR_NOT_SUPPORTED {
306+
name, ret := d.GetName()
307+
if ret != nvml.SUCCESS {
308+
return false, fmt.Errorf("failed to get device name: %v", ret)
309+
}
310+
return info.IsIntegratedGPUName(name), nil
311+
}
312+
if ret != nvml.SUCCESS {
313+
return false, fmt.Errorf("failed to get PCI info: %v", ret)
314+
}
315+
316+
if pciInfo.Domain != 0 {
317+
return false, nil
318+
}
319+
if pciInfo.Bus != 1 {
320+
return false, nil
321+
}
322+
return pciInfo.Device == 0, nil
90323
}

0 commit comments

Comments
 (0)