Skip to content

Commit db3c350

Browse files
committed
Add support for a platform override file
This change adds support for reading the detected platform (if set to `auto`) from a platform override file. This allows system administrators to explicitly select a detected platform for tooling such as the nvidia-container-toolkit, the k8s-device-plugin, and k8s-dra-driver-gpu. Signed-off-by: Evan Lezar <elezar@nvidia.com>
1 parent d0f42ba commit db3c350

File tree

3 files changed

+76
-4
lines changed

3 files changed

+76
-4
lines changed

pkg/nvlib/info/builder.go

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,6 @@ func New(opts ...Option) Interface {
5757
if o.devicelib == nil {
5858
o.devicelib = device.New(o.nvmllib)
5959
}
60-
if o.platform == "" {
61-
o.platform = PlatformAuto
62-
}
6360
if o.propertyExtractor == nil {
6461
o.propertyExtractor = &propertyExtractor{
6562
root: o.root,
@@ -70,9 +67,25 @@ func New(opts ...Option) Interface {
7067
return &infolib{
7168
PlatformResolver: &platformResolver{
7269
logger: o.logger,
73-
platform: o.platform,
70+
platform: o.normalizePlatform(),
7471
propertyExtractor: o.propertyExtractor,
7572
},
7673
PropertyExtractor: o.propertyExtractor,
7774
}
7875
}
76+
77+
func (o options) normalizePlatform() Platform {
78+
if o.platform != "" && o.platform != PlatformAuto {
79+
return o.platform
80+
}
81+
82+
override, reason := getPlaformOverride()
83+
if override != "" {
84+
o.logger.Debugf("Using platform-override %q", override)
85+
return Platform(override)
86+
}
87+
88+
o.logger.Debugf("No platform-override detected: %v", reason)
89+
90+
return PlatformAuto
91+
}

pkg/nvlib/info/resolver.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616

1717
package info
1818

19+
import (
20+
"bufio"
21+
"fmt"
22+
"os"
23+
"strings"
24+
)
25+
1926
// Platform represents a supported plaform.
2027
type Platform string
2128

@@ -62,3 +69,35 @@ func (p platformResolver) ResolvePlatform() Platform {
6269
return PlatformUnknown
6370
}
6471
}
72+
73+
// getPlatformOverride checks the system for a platform override file.
74+
// This allows system administrators to force the detection of a specific
75+
// platform.
76+
//
77+
// The first non-empty and non-comment line (starting with #) in the file is
78+
// returned.
79+
//
80+
// Note that no checks are performed for a valid platform value.
81+
//
82+
// This function can be overridden for testing purposes.
83+
var getPlaformOverride = func() (string, string) {
84+
platformOverrideFile, err := os.Open("/etc/nvidia-container-toolkit/platform-override")
85+
if os.IsNotExist(err) {
86+
return "", "platform-override file does not exist"
87+
}
88+
if err != nil {
89+
return "", fmt.Errorf("failed to open platform-override file: %w", err).Error()
90+
}
91+
defer platformOverrideFile.Close()
92+
93+
scanner := bufio.NewScanner(platformOverrideFile)
94+
for scanner.Scan() {
95+
line := strings.TrimSpace(scanner.Text())
96+
if line == "" || strings.HasPrefix(line, "#") {
97+
continue
98+
}
99+
return line, "read from platform-override file"
100+
}
101+
102+
return "", "empty platform-override file"
103+
}

pkg/nvlib/info/resolver_test.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import (
2626
func TestResolvePlatform(t *testing.T) {
2727
testCases := []struct {
2828
platform string
29+
platformOverride string
2930
hasTegraFiles bool
3031
hasDXCore bool
3132
hasNVML bool
@@ -82,10 +83,16 @@ func TestResolvePlatform(t *testing.T) {
8283
hasDXCore: true,
8384
expected: "not-auto",
8485
},
86+
{
87+
platform: "auto",
88+
platformOverride: "overridden",
89+
expected: "overridden",
90+
},
8591
}
8692

8793
for i, tc := range testCases {
8894
t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) {
95+
defer setGetPlatformOverrideForTest(tc.platformOverride)()
8996
l := New(
9097
WithPropertyExtractor(&PropertyExtractorMock{
9198
HasDXCoreFunc: func() (bool, string) {
@@ -108,3 +115,16 @@ func TestResolvePlatform(t *testing.T) {
108115
})
109116
}
110117
}
118+
119+
// setGetPlatformOverrideForTest overrides the distribution IDs that would normally be read from the /etc/os-release file.
120+
func setGetPlatformOverrideForTest(override string) func() {
121+
original := getPlaformOverride
122+
123+
getPlaformOverride = func() (string, string) {
124+
return override, "overridden for test"
125+
}
126+
127+
return func() {
128+
getPlaformOverride = original
129+
}
130+
}

0 commit comments

Comments
 (0)