diff --git a/pkg/nvlib/info/builder.go b/pkg/nvlib/info/builder.go index 6168440..e0de54a 100644 --- a/pkg/nvlib/info/builder.go +++ b/pkg/nvlib/info/builder.go @@ -57,9 +57,6 @@ func New(opts ...Option) Interface { if o.devicelib == nil { o.devicelib = device.New(o.nvmllib) } - if o.platform == "" { - o.platform = PlatformAuto - } if o.propertyExtractor == nil { o.propertyExtractor = &propertyExtractor{ root: o.root, @@ -70,9 +67,25 @@ func New(opts ...Option) Interface { return &infolib{ PlatformResolver: &platformResolver{ logger: o.logger, - platform: o.platform, + platform: o.normalizePlatform(), propertyExtractor: o.propertyExtractor, }, PropertyExtractor: o.propertyExtractor, } } + +func (o options) normalizePlatform() Platform { + if o.platform != "" && o.platform != PlatformAuto { + return o.platform + } + + override, reason := getPlaformOverride() + if override != "" { + o.logger.Debugf("Using platform-override %q", override) + return Platform(override) + } + + o.logger.Debugf("No platform-override detected: %v", reason) + + return PlatformAuto +} diff --git a/pkg/nvlib/info/resolver.go b/pkg/nvlib/info/resolver.go index 8243738..80d0cd8 100644 --- a/pkg/nvlib/info/resolver.go +++ b/pkg/nvlib/info/resolver.go @@ -16,6 +16,13 @@ package info +import ( + "bufio" + "fmt" + "os" + "strings" +) + // Platform represents a supported plaform. type Platform string @@ -62,3 +69,35 @@ func (p platformResolver) ResolvePlatform() Platform { return PlatformUnknown } } + +// getPlatformOverride checks the system for a platform override file. +// This allows system administrators to force the detection of a specific +// platform. +// +// The first non-empty and non-comment line (starting with #) in the file is +// returned. +// +// Note that no checks are performed for a valid platform value. +// +// This function can be overridden for testing purposes. +var getPlaformOverride = func() (string, string) { + platformOverrideFile, err := os.Open("/etc/nvidia-container-toolkit/platform-override") + if os.IsNotExist(err) { + return "", "platform-override file does not exist" + } + if err != nil { + return "", fmt.Errorf("failed to open platform-override file: %w", err).Error() + } + defer platformOverrideFile.Close() + + scanner := bufio.NewScanner(platformOverrideFile) + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + return line, "read from platform-override file" + } + + return "", "empty platform-override file" +} diff --git a/pkg/nvlib/info/resolver_test.go b/pkg/nvlib/info/resolver_test.go index 357c8f9..d4b6e13 100644 --- a/pkg/nvlib/info/resolver_test.go +++ b/pkg/nvlib/info/resolver_test.go @@ -26,6 +26,7 @@ import ( func TestResolvePlatform(t *testing.T) { testCases := []struct { platform string + platformOverride string hasTegraFiles bool hasDXCore bool hasNVML bool @@ -82,10 +83,16 @@ func TestResolvePlatform(t *testing.T) { hasDXCore: true, expected: "not-auto", }, + { + platform: "auto", + platformOverride: "overridden", + expected: "overridden", + }, } for i, tc := range testCases { t.Run(fmt.Sprintf("test case %d", i), func(t *testing.T) { + defer setGetPlatformOverrideForTest(tc.platformOverride)() l := New( WithPropertyExtractor(&PropertyExtractorMock{ HasDXCoreFunc: func() (bool, string) { @@ -108,3 +115,16 @@ func TestResolvePlatform(t *testing.T) { }) } } + +// setGetPlatformOverrideForTest overrides the distribution IDs that would normally be read from the /etc/os-release file. +func setGetPlatformOverrideForTest(override string) func() { + original := getPlaformOverride + + getPlaformOverride = func() (string, string) { + return override, "overridden for test" + } + + return func() { + getPlaformOverride = original + } +}