Skip to content

Commit 6f7e8b0

Browse files
authored
Merge pull request #1512 from elezar/allow-device-ids-on-cdi-generate
Add --device-id flag to nvidia-ctk cdi generate command
2 parents 4372ea6 + 311e549 commit 6f7e8b0

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ type options struct {
7474
}
7575

7676
noAllDevice bool
77+
deviceIDs []string
7778

7879
// the following are used for dependency injection during spec generation.
7980
nvmllib nvml.Interface
@@ -240,6 +241,14 @@ func (m command) build() *cli.Command {
240241
Destination: &opts.noAllDevice,
241242
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_NO_ALL_DEVICE"),
242243
},
244+
&cli.StringSliceFlag{
245+
Name: "device-id",
246+
Aliases: []string{"device-ids", "device", "devices"},
247+
Usage: "Restrict generation to the specified device identifiers",
248+
Value: []string{"all"},
249+
Destination: &opts.deviceIDs,
250+
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_DEVICE_IDS"),
251+
},
243252
},
244253
}
245254

@@ -381,7 +390,7 @@ func (m command) generateSpecs(opts *options) ([]generatedSpecs, error) {
381390
return nil, fmt.Errorf("failed to create CDI library: %v", err)
382391
}
383392

384-
allDeviceSpecs, err := cdilib.GetDeviceSpecsByID("all")
393+
allDeviceSpecs, err := cdilib.GetDeviceSpecsByID(opts.deviceIDs...)
385394
if err != nil {
386395
return nil, fmt.Errorf("failed to create device CDI specs: %v", err)
387396
}

cmd/nvidia-ctk/cdi/generate/generate_test.go

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package generate
1818

1919
import (
2020
"bytes"
21+
"fmt"
2122
"path/filepath"
2223
"strings"
2324
"testing"
@@ -47,6 +48,27 @@ func TestGenerateSpec(t *testing.T) {
4748
expectedError error
4849
expectedSpec string
4950
}{
51+
{
52+
description: "invalid device id",
53+
options: options{
54+
format: "yaml",
55+
mode: "nvml",
56+
vendor: "example.com",
57+
class: "device",
58+
deviceIDs: []string{"99"},
59+
driverRoot: driverRoot,
60+
},
61+
expectedOptions: options{
62+
format: "yaml",
63+
mode: "nvml",
64+
vendor: "example.com",
65+
class: "device",
66+
nvidiaCDIHookPath: "/usr/bin/nvidia-cdi-hook",
67+
deviceIDs: []string{"99"},
68+
driverRoot: driverRoot,
69+
},
70+
expectedError: fmt.Errorf("failed to create device CDI specs: failed to construct device spec generators: failed to get device handle from index: ERROR_INVALID_ARGUMENT"),
71+
},
5072
{
5173
description: "default",
5274
options: options{
@@ -452,6 +474,10 @@ containerEdits:
452474
for _, tc := range testCases {
453475
// Apply overrides for all test cases:
454476
tc.options.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook"
477+
if tc.options.deviceIDs == nil {
478+
tc.options.deviceIDs = []string{"all"}
479+
tc.expectedOptions.deviceIDs = []string{"all"}
480+
}
455481

456482
t.Run(tc.description, func(t *testing.T) {
457483
c := command{
@@ -481,7 +507,7 @@ containerEdits:
481507
tc.options.nvmllib = server
482508

483509
specs, err := c.generateSpecs(&tc.options)
484-
require.ErrorIs(t, err, tc.expectedError)
510+
require.EqualValues(t, err, tc.expectedError)
485511

486512
var buf bytes.Buffer
487513
for _, spec := range specs {

0 commit comments

Comments
 (0)