@@ -18,10 +18,14 @@ package nvcdi
1818
1919import (
2020 "fmt"
21+ "strings"
2122
2223 "tags.cncf.io/container-device-interface/pkg/cdi"
2324 "tags.cncf.io/container-device-interface/specs-go"
2425
26+ "github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
27+ "github.com/NVIDIA/go-nvml/pkg/nvml"
28+
2529 "github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2630 "github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2731 "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra"
@@ -46,7 +50,33 @@ func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error
4650
4751// GetDeviceSpecs returns the CDI device specs for a single device.
4852func (l * csvlib ) GetDeviceSpecs () ([]specs.Device , error ) {
49- d , err := tegra .New (
53+ d , err := l .driverDiscoverer ()
54+ if err != nil {
55+ return nil , fmt .Errorf ("failed to create driver discoverer from CSV files: %w" , err )
56+ }
57+ e , err := edits .FromDiscoverer (d )
58+ if err != nil {
59+ return nil , fmt .Errorf ("failed to create container edits for CSV files: %w" , err )
60+ }
61+
62+ names , err := l .deviceNamers .GetDeviceNames (0 , uuidIgnored {})
63+ if err != nil {
64+ return nil , fmt .Errorf ("failed to get device name: %w" , err )
65+ }
66+ var deviceSpecs []specs.Device
67+ for _ , name := range names {
68+ deviceSpec := specs.Device {
69+ Name : name ,
70+ ContainerEdits : * e .ContainerEdits ,
71+ }
72+ deviceSpecs = append (deviceSpecs , deviceSpec )
73+ }
74+
75+ return deviceSpecs , nil
76+ }
77+
78+ func (l * csvlib ) driverDiscoverer () (discover.Discover , error ) {
79+ driverDiscoverer , err := tegra .New (
5080 tegra .WithLogger (l .logger ),
5181 tegra .WithDriverRoot (l .driverRoot ),
5282 tegra .WithDevRoot (l .devRoot ),
@@ -57,27 +87,76 @@ func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
5787 tegra .WithIngorePatterns (l .csvIgnorePatterns ... ),
5888 )
5989 if err != nil {
60- return nil , fmt .Errorf ("failed to create discoverer for CSV files: %v " , err )
90+ return nil , fmt .Errorf ("failed to create discoverer for CSV files: %w " , err )
6191 }
62- e , err := edits .FromDiscoverer (d )
92+
93+ cudaCompatDiscoverer := l .cudaCompatDiscoverer ()
94+
95+ ldcacheUpdateHook , err := discover .NewLDCacheUpdateHook (l .logger , driverDiscoverer , l .hookCreator , l .ldconfigPath )
6396 if err != nil {
64- return nil , fmt .Errorf ("failed to create container edits for CSV files : %v " , err )
97+ return nil , fmt .Errorf ("failed to create ldcache update hook discoverer : %w " , err )
6598 }
6699
67- names , err := l .deviceNamers .GetDeviceNames (0 , uuidIgnored {})
100+ d := discover .Merge (
101+ driverDiscoverer ,
102+ cudaCompatDiscoverer ,
103+ // The ldcacheUpdateHook is added last to ensure that the created symlinks are included
104+ ldcacheUpdateHook ,
105+ )
106+ return d , nil
107+ }
108+
109+ // cudaCompatDiscoverer returns a discoverer for the CUDA forward compat hook
110+ // on Tegra-based systems.
111+ // It the system has NVML available, this is used to determine the driver
112+ // version to be passed to the hook.
113+ // On Orin-based systems, the compat library root in the container is also set.
114+ func (l * csvlib ) cudaCompatDiscoverer () discover.Discover {
115+ hasNvml , _ := l .infolib .HasNvml ()
116+ if ! hasNvml {
117+ return nil
118+ }
119+
120+ ret := l .nvmllib .Init ()
121+ if ret != nvml .SUCCESS {
122+ l .logger .Warningf ("Failed to initialize NVML: %v" , ret )
123+ return nil
124+ }
125+ defer func () {
126+ _ = l .nvmllib .Shutdown ()
127+ }()
128+
129+ version , ret := l .nvmllib .SystemGetDriverVersion ()
130+ if ret != nvml .SUCCESS {
131+ l .logger .Warningf ("Failed to get driver version: %v" , ret )
132+ return nil
133+ }
134+
135+ var names []string
136+ err := l .devicelib .VisitDevices (func (i int , d device.Device ) error {
137+ name , ret := d .GetName ()
138+ if ret != nvml .SUCCESS {
139+ return fmt .Errorf ("device %v: %v" , i , ret )
140+ }
141+ names = append (names , name )
142+ return nil
143+ })
68144 if err != nil {
69- return nil , fmt .Errorf ("failed to get device name: %v" , err )
145+ l .logger .Warningf ("Failed to get device names: %v" , err )
146+ return nil
70147 }
71- var deviceSpecs []specs.Device
148+
149+ var cudaCompatContainerRoot string
72150 for _ , name := range names {
73- deviceSpec := specs.Device {
74- Name : name ,
75- ContainerEdits : * e .ContainerEdits ,
151+ // TODO: Should this be overridable through a feature flag / config option?
152+ if strings .Contains (name , "Orin (nvgpu)" ) {
153+ // TODO: This should probably be a constant or configurable.
154+ cudaCompatContainerRoot = "/usr/local/cuda/compat-orin"
155+ break
76156 }
77- deviceSpecs = append (deviceSpecs , deviceSpec )
78157 }
79158
80- return deviceSpecs , nil
159+ return discover . NewCUDACompatHookDiscoverer ( l . logger , l . hookCreator , version , cudaCompatContainerRoot )
81160}
82161
83162// GetCommonEdits generates a CDI specification that can be used for ANY devices
0 commit comments