@@ -18,20 +18,49 @@ package nvcdi
1818
1919import (
2020 "fmt"
21+ "slices"
22+ "strconv"
2123
2224 "tags.cncf.io/container-device-interface/pkg/cdi"
2325 "tags.cncf.io/container-device-interface/specs-go"
2426
27+ "github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
28+ "github.com/NVIDIA/go-nvlib/pkg/nvlib/info"
29+ "github.com/NVIDIA/go-nvml/pkg/nvml"
30+ "github.com/google/uuid"
31+
2532 "github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2633 "github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2734 "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra"
2835)
2936
3037type csvlib nvcdilib
3138
39+ type mixedcsvlib nvcdilib
40+
3241var _ deviceSpecGeneratorFactory = (* csvlib )(nil )
3342
43+ // DeviceSpecGenerators creates a set of generators for the specified set of
44+ // devices.
45+ // If NVML is not available or the disable-multiple-csv-devices feature flag is
46+ // enabled, a single device is assumed.
3447func (l * csvlib ) DeviceSpecGenerators (ids ... string ) (DeviceSpecGenerator , error ) {
48+ if l .featureFlags [FeatureDisableMultipleCSVDevices ] {
49+ return l .purecsvDeviceSpecGenerators (ids ... )
50+ }
51+ hasNVML , _ := l .infolib .HasNvml ()
52+ if ! hasNVML {
53+ return l .purecsvDeviceSpecGenerators (ids ... )
54+ }
55+ mixed , err := l .mixedDeviceSpecGenerators (ids ... )
56+ if err != nil {
57+ l .logger .Warningf ("Failed to create mixed CSV spec generator; falling back to pure CSV implementation: %v" , err )
58+ return l .purecsvDeviceSpecGenerators (ids ... )
59+ }
60+ return mixed , nil
61+ }
62+
63+ func (l * csvlib ) purecsvDeviceSpecGenerators (ids ... string ) (DeviceSpecGenerator , error ) {
3564 for _ , id := range ids {
3665 switch id {
3766 case "all" :
@@ -40,35 +69,42 @@ func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error
4069 return nil , fmt .Errorf ("unsupported device id: %v" , id )
4170 }
4271 }
72+ g := & csvDeviceGenerator {
73+ csvlib : l ,
74+ index : 0 ,
75+ uuid : "" ,
76+ }
77+ return g , nil
78+ }
79+
80+ func (l * csvlib ) mixedDeviceSpecGenerators (ids ... string ) (DeviceSpecGenerator , error ) {
81+ return (* mixedcsvlib )(l ).DeviceSpecGenerators (ids ... )
82+ }
4383
44- return l , nil
84+ // A csvDeviceGenerator generates CDI specs for a device based on a set of
85+ // platform-specific CSV files.
86+ type csvDeviceGenerator struct {
87+ * csvlib
88+ index int
89+ uuid string
90+ }
91+
92+ func (l * csvDeviceGenerator ) GetUUID () (string , error ) {
93+ return l .uuid , nil
4594}
4695
4796// GetDeviceSpecs returns the CDI device specs for a single device.
48- func (l * csvlib ) GetDeviceSpecs () ([]specs.Device , error ) {
49- d , err := tegra .New (
50- tegra .WithLogger (l .logger ),
51- tegra .WithDriverRoot (l .driverRoot ),
52- tegra .WithDevRoot (l .devRoot ),
53- tegra .WithHookCreator (l .hookCreator ),
54- tegra .WithLdconfigPath (l .ldconfigPath ),
55- tegra .WithLibrarySearchPaths (l .librarySearchPaths ... ),
56- tegra .WithMountSpecs (
57- tegra .Transform (
58- tegra .MountSpecsFromCSVFiles (l .logger , l .csvFiles ... ),
59- tegra .IgnoreSymlinkMountSpecsByPattern (l .csvIgnorePatterns ... ),
60- ),
61- ),
62- )
97+ func (l * csvDeviceGenerator ) GetDeviceSpecs () ([]specs.Device , error ) {
98+ deviceNodeDiscoverer , err := l .deviceNodeDiscoverer ()
6399 if err != nil {
64- return nil , fmt .Errorf ("failed to create discoverer for CSV files: %v " , err )
100+ return nil , fmt .Errorf ("failed to create discoverer for device nodes from CSV files: %w " , err )
65101 }
66- e , err := edits .FromDiscoverer (d )
102+ e , err := edits .FromDiscoverer (deviceNodeDiscoverer )
67103 if err != nil {
68104 return nil , fmt .Errorf ("failed to create container edits for CSV files: %v" , err )
69105 }
70106
71- names , err := l .deviceNamers .GetDeviceNames (0 , uuidIgnored {} )
107+ names , err := l .deviceNamers .GetDeviceNames (l . index , l )
72108 if err != nil {
73109 return nil , fmt .Errorf ("failed to get device name: %v" , err )
74110 }
@@ -84,7 +120,204 @@ func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
84120 return deviceSpecs , nil
85121}
86122
123+ // deviceNodeDiscoverer creates a discoverer for the device nodes associated
124+ // with the specified device.
125+ // The CSV mount specs are used as the source for which device nodes are
126+ // required with the following additions:
127+ //
128+ // - Any regular device nodes (i.e. /dev/nvidia[0-9]+) are removed from the
129+ // input set.
130+ // - The device node (i.e. /dev/nvidia{{ .index }}) associated with this
131+ // particular device is added to the set of device nodes to be discovered.
132+ func (l * csvDeviceGenerator ) deviceNodeDiscoverer () (discover.Discover , error ) {
133+ mountSpecs := tegra .Transform (
134+ tegra .Transform (
135+ tegra .MountSpecsFromCSVFiles (l .logger , l .csvFiles ... ),
136+ // We remove non-device nodes.
137+ tegra .OnlyDeviceNodes (),
138+ ),
139+ // We remove the regular (nvidia[0-9]+) device nodes.
140+ tegra .WithoutRegularDeviceNodes (),
141+ )
142+ return tegra .New (
143+ tegra .WithLogger (l .logger ),
144+ tegra .WithDriverRoot (l .driverRoot ),
145+ tegra .WithDevRoot (l .devRoot ),
146+ tegra .WithHookCreator (l .hookCreator ),
147+ tegra .WithLdconfigPath (l .ldconfigPath ),
148+ tegra .WithLibrarySearchPaths (l .librarySearchPaths ... ),
149+ tegra .WithMountSpecs (
150+ mountSpecs ,
151+ // We add the specific device node for this device.
152+ tegra .DeviceNodes (fmt .Sprintf ("/dev/nvidia%d" , l .index )),
153+ ),
154+ )
155+ }
156+
87157// GetCommonEdits generates a CDI specification that can be used for ANY devices
158+ // These explicitly do not include any device nodes.
88159func (l * csvlib ) GetCommonEdits () (* cdi.ContainerEdits , error ) {
89- return edits .FromDiscoverer (discover.None {})
160+ mountSpecs := tegra .Transform (
161+ tegra .Transform (
162+ tegra .MountSpecsFromCSVFiles (l .logger , l .csvFiles ... ),
163+ tegra .WithoutDeviceNodes (),
164+ ),
165+ tegra .IgnoreSymlinkMountSpecsByPattern (l .csvIgnorePatterns ... ),
166+ )
167+ driverDiscoverer , err := tegra .New (
168+ tegra .WithLogger (l .logger ),
169+ tegra .WithDriverRoot (l .driverRoot ),
170+ tegra .WithDevRoot (l .devRoot ),
171+ tegra .WithHookCreator (l .hookCreator ),
172+ tegra .WithLdconfigPath (l .ldconfigPath ),
173+ tegra .WithLibrarySearchPaths (l .librarySearchPaths ... ),
174+ tegra .WithMountSpecs (mountSpecs ),
175+ )
176+ if err != nil {
177+ return nil , fmt .Errorf ("failed to create driver discoverer from CSV files: %w" , err )
178+ }
179+ return edits .FromDiscoverer (driverDiscoverer )
180+ }
181+
182+ func (l * mixedcsvlib ) DeviceSpecGenerators (ids ... string ) (DeviceSpecGenerator , error ) {
183+ asNvmlLib := (* nvmllib )(l )
184+ err := asNvmlLib .init ()
185+ if err != nil {
186+ return nil , fmt .Errorf ("failed to initialize nvml: %w" , err )
187+ }
188+ defer asNvmlLib .tryShutdown ()
189+
190+ if slices .Contains (ids , "all" ) {
191+ ids , err = l .getAllDeviceIndices ()
192+ if err != nil {
193+ return nil , fmt .Errorf ("failed to get device indices: %w" , err )
194+ }
195+ }
196+
197+ var DeviceSpecGenerators DeviceSpecGenerators
198+ for _ , id := range ids {
199+ generator , err := l .deviceSpecGeneratorForId (device .Identifier (id ))
200+ if err != nil {
201+ return nil , fmt .Errorf ("failed to create device spec generator for device %q: %w" , id , err )
202+ }
203+ DeviceSpecGenerators = append (DeviceSpecGenerators , generator )
204+ }
205+
206+ return DeviceSpecGenerators , nil
207+ }
208+
209+ func (l * mixedcsvlib ) getAllDeviceIndices () ([]string , error ) {
210+ numDevices , ret := l .nvmllib .DeviceGetCount ()
211+ if ret != nvml .SUCCESS {
212+ return nil , fmt .Errorf ("faled to get device count: %v" , ret )
213+ }
214+
215+ var allIndices []string
216+ for index := range numDevices {
217+ allIndices = append (allIndices , fmt .Sprintf ("%d" , index ))
218+ }
219+ return allIndices , nil
220+ }
221+
222+ func (l * mixedcsvlib ) deviceSpecGeneratorForId (id device.Identifier ) (DeviceSpecGenerator , error ) {
223+ switch {
224+ case id .IsGpuUUID (), isIntegratedGPUID (id ):
225+ uuid := string (id )
226+ device , ret := l .nvmllib .DeviceGetHandleByUUID (uuid )
227+ if ret != nvml .SUCCESS {
228+ return nil , fmt .Errorf ("failed to get device handle from UUID %q: %v" , uuid , ret )
229+ }
230+ index , ret := device .GetIndex ()
231+ if ret != nvml .SUCCESS {
232+ return nil , fmt .Errorf ("failed to get device index: %v" , ret )
233+ }
234+ return l .csvDeviceSpecGenerator (index , uuid , device )
235+ case id .IsGpuIndex ():
236+ index , err := strconv .Atoi (string (id ))
237+ if err != nil {
238+ return nil , fmt .Errorf ("failed to convert device index to an int: %w" , err )
239+ }
240+ device , ret := l .nvmllib .DeviceGetHandleByIndex (index )
241+ if ret != nvml .SUCCESS {
242+ return nil , fmt .Errorf ("failed to get device handle from index: %v" , ret )
243+ }
244+ uuid , ret := device .GetUUID ()
245+ if ret != nvml .SUCCESS {
246+ return nil , fmt .Errorf ("failed to get UUID: %v" , ret )
247+ }
248+ return l .csvDeviceSpecGenerator (index , uuid , device )
249+ case id .IsMigUUID ():
250+ fallthrough
251+ case id .IsMigIndex ():
252+ return nil , fmt .Errorf ("generating a CDI spec for MIG id %q is not supported in CSV mode" , id )
253+ }
254+ return nil , fmt .Errorf ("identifier is not a valid UUID or index: %q" , id )
255+ }
256+
257+ func (l * mixedcsvlib ) csvDeviceSpecGenerator (index int , uuid string , device nvml.Device ) (DeviceSpecGenerator , error ) {
258+ isIntegrated , err := isIntegratedGPU (device )
259+ if err != nil {
260+ return nil , fmt .Errorf ("is-integrated check failed for device (index=%v,uuid=%v)" , index , uuid )
261+ }
262+
263+ g := & csvDeviceGenerator {
264+ csvlib : (* csvlib )(l ),
265+ index : index ,
266+ uuid : uuid ,
267+ }
268+
269+ if ! isIntegrated {
270+ csvDeviceNodeDiscoverer , err := g .deviceNodeDiscoverer ()
271+ if err != nil {
272+ return nil , fmt .Errorf ("failed to create discoverer for devices nodes: %w" , err )
273+ }
274+
275+ // If this is not an integrated GPU, we also create a spec generator for
276+ // the full GPU.
277+ dgpu := (* nvmllib )(l ).withInit (& fullGPUDeviceSpecGenerator {
278+ nvmllib : (* nvmllib )(l ),
279+ uuid : uuid ,
280+ index : index ,
281+ // For the CSV case, we include the control device nodes at a
282+ // device level.
283+ additionalDiscoverers : []discover.Discover {
284+ (* nvmllib )(l ).controlDeviceNodeDiscoverer (),
285+ csvDeviceNodeDiscoverer ,
286+ },
287+ featureFlags : l .featureFlags ,
288+ })
289+ return dgpu , nil
290+ }
291+
292+ return g , nil
293+ }
294+
295+ func isIntegratedGPUID (id device.Identifier ) bool {
296+ _ , err := uuid .Parse (string (id ))
297+ return err == nil
298+ }
299+
300+ // isIntegratedGPU checks whether the specified device is an integrated GPU.
301+ // As a proxy we check the PCI Bus if for thes
302+ // TODO: This should be replaced by an explicit NVML call once available.
303+ func isIntegratedGPU (d nvml.Device ) (bool , error ) {
304+ pciInfo , ret := d .GetPciInfo ()
305+ if ret == nvml .ERROR_NOT_SUPPORTED {
306+ name , ret := d .GetName ()
307+ if ret != nvml .SUCCESS {
308+ return false , fmt .Errorf ("failed to get device name: %v" , ret )
309+ }
310+ return info .IsIntegratedGPUName (name ), nil
311+ }
312+ if ret != nvml .SUCCESS {
313+ return false , fmt .Errorf ("failed to get PCI info: %v" , ret )
314+ }
315+
316+ if pciInfo .Domain != 0 {
317+ return false , nil
318+ }
319+ if pciInfo .Bus != 1 {
320+ return false , nil
321+ }
322+ return pciInfo .Device == 0 , nil
90323}
0 commit comments