@@ -31,7 +31,7 @@ import (
3131)
3232
3333const (
34- cudaCompatPath = "/usr/local/cuda/compat"
34+ defaultCudaCompatPath = "/usr/local/cuda/compat"
3535 // cudaCompatLdsoconfdFilenamePattern specifies the pattern for the filename
3636 // in ld.so.conf.d that includes a reference to the CUDA compat path.
3737 // The 00-compat prefix is chosen to ensure that these libraries have a
@@ -44,7 +44,8 @@ type command struct {
4444}
4545
4646type options struct {
47- hostDriverVersion string
47+ cudaCompatContainerRoot string
48+ hostDriverVersion string
4849 // containerSpec allows the path to the container spec to be specified for
4950 // testing.
5051 containerSpec string
@@ -78,6 +79,12 @@ func (m command) build() *cli.Command {
7879 Usage : "Specify the host driver version. If the CUDA compat libraries detected in the container do not have a higher MAJOR version, the hook is a no-op." ,
7980 Destination : & cfg .hostDriverVersion ,
8081 },
82+ & cli.StringFlag {
83+ Name : "cuda-compat-container-root" ,
84+ Usage : "Specify the folder in which CUDA compat libraries are located in the container" ,
85+ Value : defaultCudaCompatPath ,
86+ Destination : & cfg .cudaCompatContainerRoot ,
87+ },
8188 & cli.StringFlag {
8289 Name : "container-spec" ,
8390 Hidden : true ,
@@ -110,7 +117,7 @@ func (m command) run(_ *cli.Command, cfg *options) error {
110117 return fmt .Errorf ("failed to determined container root: %w" , err )
111118 }
112119
113- containerForwardCompatDir , err := m .getContainerForwardCompatDir (containerRoot (containerRootDir ), cfg .hostDriverVersion )
120+ containerForwardCompatDir , err := m .getContainerForwardCompatDir (containerRoot (containerRootDir ), cfg .cudaCompatContainerRoot , cfg . hostDriverVersion )
114121 if err != nil {
115122 return fmt .Errorf ("failed to get container forward compat directory: %w" , err )
116123 }
@@ -121,17 +128,17 @@ func (m command) run(_ *cli.Command, cfg *options) error {
121128 return m .createLdsoconfdFile (containerRoot (containerRootDir ), cudaCompatLdsoconfdFilenamePattern , containerForwardCompatDir )
122129}
123130
124- func (m command ) getContainerForwardCompatDir (containerRoot containerRoot , hostDriverVersion string ) (string , error ) {
131+ func (m command ) getContainerForwardCompatDir (containerRoot containerRoot , cudaCompatRoot string , hostDriverVersion string ) (string , error ) {
125132 if hostDriverVersion == "" {
126133 m .logger .Debugf ("Host driver version not specified" )
127134 return "" , nil
128135 }
129- if ! containerRoot .hasPath (cudaCompatPath ) {
136+ if ! containerRoot .hasPath (cudaCompatRoot ) {
130137 m .logger .Debugf ("No CUDA forward compatibility libraries directory in container" )
131138 return "" , nil
132139 }
133140
134- libs , err := containerRoot .globFiles (filepath .Join (cudaCompatPath , "libcuda.so.*.*" ))
141+ libs , err := containerRoot .globFiles (filepath .Join (cudaCompatRoot , "libcuda.so.*.*" ))
135142 if err != nil {
136143 m .logger .Warningf ("Failed to find CUDA compat library: %w" , err )
137144 return "" , nil
0 commit comments