@@ -18,10 +18,14 @@ package nvcdi
1818
1919import (
2020 "fmt"
21+ "strings"
2122
2223 "tags.cncf.io/container-device-interface/pkg/cdi"
2324 "tags.cncf.io/container-device-interface/specs-go"
2425
26+ "github.com/NVIDIA/go-nvlib/pkg/nvlib/device"
27+ "github.com/NVIDIA/go-nvml/pkg/nvml"
28+
2529 "github.com/NVIDIA/nvidia-container-toolkit/internal/discover"
2630 "github.com/NVIDIA/nvidia-container-toolkit/internal/edits"
2731 "github.com/NVIDIA/nvidia-container-toolkit/internal/platform-support/tegra"
@@ -46,7 +50,33 @@ func (l *csvlib) DeviceSpecGenerators(ids ...string) (DeviceSpecGenerator, error
4650
4751// GetDeviceSpecs returns the CDI device specs for a single device.
4852func (l * csvlib ) GetDeviceSpecs () ([]specs.Device , error ) {
49- d , err := tegra .New (
53+ d , err := l .driverDiscoverer ()
54+ if err != nil {
55+ return nil , fmt .Errorf ("failed to create driver discoverer from CSV files: %w" , err )
56+ }
57+ e , err := edits .FromDiscoverer (d )
58+ if err != nil {
59+ return nil , fmt .Errorf ("failed to create container edits for CSV files: %w" , err )
60+ }
61+
62+ names , err := l .deviceNamers .GetDeviceNames (0 , uuidIgnored {})
63+ if err != nil {
64+ return nil , fmt .Errorf ("failed to get device name: %w" , err )
65+ }
66+ var deviceSpecs []specs.Device
67+ for _ , name := range names {
68+ deviceSpec := specs.Device {
69+ Name : name ,
70+ ContainerEdits : * e .ContainerEdits ,
71+ }
72+ deviceSpecs = append (deviceSpecs , deviceSpec )
73+ }
74+
75+ return deviceSpecs , nil
76+ }
77+
78+ func (l * csvlib ) driverDiscoverer () (discover.Discover , error ) {
79+ driverDiscoverer , err := tegra .New (
5080 tegra .WithLogger (l .logger ),
5181 tegra .WithDriverRoot (l .driverRoot ),
5282 tegra .WithDevRoot (l .devRoot ),
@@ -57,27 +87,69 @@ func (l *csvlib) GetDeviceSpecs() ([]specs.Device, error) {
5787 tegra .WithIngorePatterns (l .csvIgnorePatterns ... ),
5888 )
5989 if err != nil {
60- return nil , fmt .Errorf ("failed to create discoverer for CSV files: %v " , err )
90+ return nil , fmt .Errorf ("failed to create discoverer for CSV files: %w " , err )
6191 }
62- e , err := edits .FromDiscoverer (d )
92+
93+ cudaCompatDiscoverer := l .cudaCompatDiscoverer ()
94+
95+ ldcacheUpdateHook , err := discover .NewLDCacheUpdateHook (l .logger , driverDiscoverer , l .hookCreator , l .ldconfigPath )
6396 if err != nil {
64- return nil , fmt .Errorf ("failed to create container edits for CSV files : %v " , err )
97+ return nil , fmt .Errorf ("failed to create ldcache update hook discoverer : %w " , err )
6598 }
6699
67- names , err := l .deviceNamers .GetDeviceNames (0 , uuidIgnored {})
100+ d := discover .Merge (
101+ driverDiscoverer ,
102+ cudaCompatDiscoverer ,
103+ // The ldcacheUpdateHook is added last to ensure that the created symlinks are included
104+ ldcacheUpdateHook ,
105+ )
106+ return d , nil
107+ }
108+
109+ func (l * csvlib ) cudaCompatDiscoverer () discover.Discover {
110+ hasNvml , _ := l .infolib .HasNvml ()
111+ if ! hasNvml {
112+ return nil
113+ }
114+
115+ ret := l .nvmllib .Init ()
116+ if ret != nvml .SUCCESS {
117+ l .logger .Warningf ("Failed to initialize NVML: %v" , ret )
118+ return nil
119+ }
120+ defer func () {
121+ _ = l .nvmllib .Shutdown ()
122+ }()
123+
124+ version , ret := l .nvmllib .SystemGetDriverVersion ()
125+ if ret != nvml .SUCCESS {
126+ l .logger .Warningf ("Failed to get driver version: %v" , ret )
127+ return nil
128+ }
129+
130+ var names []string
131+ err := l .devicelib .VisitDevices (func (i int , d device.Device ) error {
132+ name , ret := d .GetName ()
133+ if ret != nvml .SUCCESS {
134+ return fmt .Errorf ("device %v: %v" , i , ret )
135+ }
136+ names = append (names , name )
137+ return nil
138+ })
68139 if err != nil {
69- return nil , fmt . Errorf ( "failed to get device name : %v" , err )
140+ l . logger . Warningf ( "Failed to get device names : %v" , err )
70141 }
71- var deviceSpecs []specs.Device
142+
143+ var cudaCompatContainerRoot string
72144 for _ , name := range names {
73- deviceSpec := specs.Device {
74- Name : name ,
75- ContainerEdits : * e .ContainerEdits ,
145+ if strings .Contains (name , "Orin (nvgpu)" ) {
146+ // TODO: This should probably be a constant.
147+ cudaCompatContainerRoot = "/usr/local/cuda/compat-orin"
148+ break
76149 }
77- deviceSpecs = append (deviceSpecs , deviceSpec )
78150 }
79151
80- return deviceSpecs , nil
152+ return discover . NewCUDACompatHookDiscoverer ( l . logger , l . hookCreator , version , cudaCompatContainerRoot )
81153}
82154
83155// GetCommonEdits generates a CDI specification that can be used for ANY devices
0 commit comments