@@ -31,7 +31,7 @@ import (
3131)
3232
3333const (
34- cudaCompatPath = "/usr/local/cuda/compat"
34+ defaultCudaCompatPath = "/usr/local/cuda/compat"
3535 // cudaCompatLdsoconfdFilenamePattern specifies the pattern for the filename
3636 // in ld.so.conf.d that includes a reference to the CUDA compat path.
3737 // The 00-compat prefix is chosen to ensure that these libraries have a
@@ -44,8 +44,9 @@ type command struct {
4444}
4545
4646type options struct {
47- hostDriverVersion string
48- containerSpec string
47+ cudaCompatContainerRoot string
48+ hostDriverVersion string
49+ containerSpec string
4950}
5051
5152// NewCommand constructs a cuda-compat command with the specified logger
@@ -76,6 +77,12 @@ func (m command) build() *cli.Command {
7677 Usage : "Specify the host driver version. If the CUDA compat libraries detected in the container do not have a higher MAJOR version, the hook is a no-op." ,
7778 Destination : & cfg .hostDriverVersion ,
7879 },
80+ & cli.StringFlag {
81+ Name : "cuda-compat-container-root" ,
82+ Usage : "Specify the folder in which CUDA compat libraries are located in the container" ,
83+ Value : defaultCudaCompatPath ,
84+ Destination : & cfg .cudaCompatContainerRoot ,
85+ },
7986 & cli.StringFlag {
8087 Name : "container-spec" ,
8188 Hidden : true ,
@@ -108,7 +115,7 @@ func (m command) run(_ *cli.Command, cfg *options) error {
108115 return fmt .Errorf ("failed to determined container root: %w" , err )
109116 }
110117
111- containerForwardCompatDir , err := m .getContainerForwardCompatDir (containerRoot (containerRootDir ), cfg .hostDriverVersion )
118+ containerForwardCompatDir , err := m .getContainerForwardCompatDir (containerRoot (containerRootDir ), cfg .cudaCompatContainerRoot , cfg . hostDriverVersion )
112119 if err != nil {
113120 return fmt .Errorf ("failed to get container forward compat directory: %w" , err )
114121 }
@@ -119,17 +126,17 @@ func (m command) run(_ *cli.Command, cfg *options) error {
119126 return m .createLdsoconfdFile (containerRoot (containerRootDir ), cudaCompatLdsoconfdFilenamePattern , containerForwardCompatDir )
120127}
121128
122- func (m command ) getContainerForwardCompatDir (containerRoot containerRoot , hostDriverVersion string ) (string , error ) {
129+ func (m command ) getContainerForwardCompatDir (containerRoot containerRoot , cudaCompatRoot string , hostDriverVersion string ) (string , error ) {
123130 if hostDriverVersion == "" {
124131 m .logger .Debugf ("Host driver version not specified" )
125132 return "" , nil
126133 }
127- if ! containerRoot .hasPath (cudaCompatPath ) {
134+ if ! containerRoot .hasPath (cudaCompatRoot ) {
128135 m .logger .Debugf ("No CUDA forward compatibility libraries directory in container" )
129136 return "" , nil
130137 }
131138
132- libs , err := containerRoot .globFiles (filepath .Join (cudaCompatPath , "libcuda.so.*.*" ))
139+ libs , err := containerRoot .globFiles (filepath .Join (cudaCompatRoot , "libcuda.so.*.*" ))
133140 if err != nil {
134141 m .logger .Warningf ("Failed to find CUDA compat library: %w" , err )
135142 return "" , nil
0 commit comments