Skip to content

Commit 181ac42

Browse files
committed
Allow device IDs to be specified when generating CDI specs
Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent 07f71f1 commit 181ac42

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

cmd/nvidia-ctk/cdi/generate/generate.go

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ type options struct {
7373
ignorePatterns []string
7474
}
7575

76+
deviceIDs []string
77+
7678
// the following are used for dependency injection during spec generation.
7779
nvmllib nvml.Interface
7880
}
@@ -232,6 +234,14 @@ func (m command) build() *cli.Command {
232234
Destination: &opts.featureFlags,
233235
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_FEATURE_FLAGS"),
234236
},
237+
&cli.StringSliceFlag{
238+
Name: "device-id",
239+
Aliases: []string{"device-ids", "device", "devices"},
240+
Usage: "Restrict generation to the specified device identifiers",
241+
Value: []string{"all"},
242+
Destination: &opts.deviceIDs,
243+
Sources: cli.EnvVars("NVIDIA_CTK_CDI_GENERATE_DEVICE_IDS"),
244+
},
235245
},
236246
}
237247

@@ -373,7 +383,7 @@ func (m command) generateSpecs(opts *options) ([]generatedSpecs, error) {
373383
return nil, fmt.Errorf("failed to create CDI library: %v", err)
374384
}
375385

376-
allDeviceSpecs, err := cdilib.GetDeviceSpecsByID("all")
386+
allDeviceSpecs, err := cdilib.GetDeviceSpecsByID(opts.deviceIDs...)
377387
if err != nil {
378388
return nil, fmt.Errorf("failed to create device CDI specs: %v", err)
379389
}

cmd/nvidia-ctk/cdi/generate/generate_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,10 @@ containerEdits:
452452
for _, tc := range testCases {
453453
// Apply overrides for all test cases:
454454
tc.options.nvidiaCDIHookPath = "/usr/bin/nvidia-cdi-hook"
455+
if tc.options.deviceIDs == nil {
456+
tc.options.deviceIDs = []string{"all"}
457+
tc.expectedOptions.deviceIDs = []string{"all"}
458+
}
455459

456460
t.Run(tc.description, func(t *testing.T) {
457461
c := command{

0 commit comments

Comments
 (0)