@@ -28,8 +28,8 @@ use std::env;
2828use std:: fs;
2929use std:: path;
3030
31- pub mod callbacks;
32- pub mod cuda_sdk;
31+ mod callbacks;
32+ mod cuda_sdk;
3333
3434fn main ( ) {
3535 let outdir = path:: PathBuf :: from (
@@ -63,8 +63,9 @@ fn main() {
6363 println ! ( "cargo::rerun-if-env-changed={}" , e) ;
6464 }
6565
66- create_cuda_driver_bindings ( & sdk, outdir. as_path ( ) ) ;
67- create_cuda_runtime_bindings ( & sdk, outdir. as_path ( ) ) ;
66+ create_driver_bindings ( & sdk, outdir. as_path ( ) ) ;
67+ create_runtime_bindings ( & sdk, outdir. as_path ( ) ) ;
68+ create_runtime_types_bindings ( & sdk, outdir. as_path ( ) ) ;
6869 create_cublas_bindings ( & sdk, outdir. as_path ( ) ) ;
6970 create_nptx_compiler_bindings ( & sdk, outdir. as_path ( ) ) ;
7071 create_nvvm_bindings ( & sdk, outdir. as_path ( ) ) ;
@@ -73,8 +74,8 @@ fn main() {
7374 feature = "driver" ,
7475 feature = "runtime" ,
7576 feature = "cublas" ,
76- feature = "cublaslt " ,
77- feature = "cublasxt "
77+ feature = "cublasLt " ,
78+ feature = "cublasXt "
7879 ) ) {
7980 for libdir in sdk. cuda_library_paths ( ) {
8081 println ! ( "cargo::rustc-link-search=native={}" , libdir. display( ) ) ;
@@ -84,11 +85,11 @@ fn main() {
8485 if cfg ! ( feature = "runtime" ) {
8586 println ! ( "cargo::rustc-link-lib=dylib=cudart" ) ;
8687 }
87- if cfg ! ( feature = "cublas" ) || cfg ! ( feature = "cublasxt " ) {
88+ if cfg ! ( feature = "cublas" ) || cfg ! ( feature = "cublasXt " ) {
8889 println ! ( "cargo::rustc-link-lib=dylib=cublas" ) ;
8990 }
90- if cfg ! ( feature = "cublaslt " ) {
91- println ! ( "cargo::rustc-link-lib=dylib=cublaslt " ) ;
91+ if cfg ! ( feature = "cublasLt " ) {
92+ println ! ( "cargo::rustc-link-lib=dylib=cublasLt " ) ;
9293 }
9394 if cfg ! ( feature = "nvvm" ) {
9495 for libdir in sdk. nvvm_library_paths ( ) {
@@ -101,7 +102,53 @@ fn main() {
101102 }
102103}
103104
104- fn create_cuda_driver_bindings ( sdk : & cuda_sdk:: CudaSdk , outdir : & path:: Path ) {
105+ fn create_runtime_types_bindings ( sdk : & cuda_sdk:: CudaSdk , outdir : & path:: Path ) {
106+ let params = & [
107+ ( cfg ! ( feature = "driver_types" ) , "driver_types" ) ,
108+ ( cfg ! ( feature = "library_types" ) , "library_types" ) ,
109+ ( cfg ! ( feature = "vector_types" ) , "vector_types" ) ,
110+ ( cfg ! ( feature = "texture_types" ) , "texture_types" ) ,
111+ ( cfg ! ( feature = "surface_types" ) , "surface_types" ) ,
112+ ( cfg ! ( feature = "cuComplex" ) , "cuComplex" ) ,
113+ ] ;
114+ for ( should_generate, pkg) in params {
115+ if !should_generate {
116+ continue ;
117+ }
118+ let bindgen_path = path:: PathBuf :: from ( format ! ( "{}/{}_sys.rs" , outdir. display( ) , pkg) ) ;
119+ let header = sdk
120+ . cuda_root ( )
121+ . join ( format ! ( "include/{}.h" , pkg) )
122+ . display ( )
123+ . to_string ( ) ;
124+ let bindings = bindgen:: Builder :: default ( )
125+ . header ( & header)
126+ . parse_callbacks ( Box :: new ( bindgen:: CargoCallbacks :: new ( ) ) )
127+ . clang_args (
128+ sdk. cuda_include_paths ( )
129+ . iter ( )
130+ . map ( |p| format ! ( "-I{}" , p. display( ) ) ) ,
131+ )
132+ . allowlist_file ( format ! ( r".*{pkg}\.h" ) )
133+ . allowlist_recursively ( false )
134+ . default_enum_style ( bindgen:: EnumVariation :: Rust {
135+ non_exhaustive : false ,
136+ } )
137+ . derive_default ( true )
138+ . derive_eq ( true )
139+ . derive_hash ( true )
140+ . derive_ord ( true )
141+ . size_t_is_usize ( true )
142+ . layout_tests ( true )
143+ . generate ( )
144+ . unwrap_or_else ( |e| panic ! ( "Unable to generate {pkg} bindings: {e}" ) ) ;
145+ bindings
146+ . write_to_file ( bindgen_path. as_path ( ) )
147+ . unwrap_or_else ( |e| panic ! ( "Cannot write {pkg} bindgen output to file: {e}" ) ) ;
148+ }
149+ }
150+
151+ fn create_driver_bindings ( sdk : & cuda_sdk:: CudaSdk , outdir : & path:: Path ) {
105152 if !cfg ! ( feature = "driver" ) {
106153 return ;
107154 }
@@ -121,13 +168,7 @@ fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
121168 . iter ( )
122169 . map ( |p| format ! ( "-I{}" , p. display( ) ) ) ,
123170 )
124- . allowlist_type ( "^CU.*" )
125- . allowlist_type ( "^cuuint(32|64)_t" )
126- . allowlist_type ( "^cudaError_enum" )
127- . allowlist_type ( "^cu.*Complex$" )
128- . allowlist_type ( "^cuda.*" )
129- . allowlist_var ( "^CU.*" )
130- . allowlist_function ( "^cu.*" )
171+ . allowlist_file ( r".*cuda[^/\\]*\.h" )
131172 . default_enum_style ( bindgen:: EnumVariation :: Rust {
132173 non_exhaustive : false ,
133174 } )
@@ -145,7 +186,7 @@ fn create_cuda_driver_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
145186 . expect ( "Cannot write CUDA driver bindgen output to file." ) ;
146187}
147188
148- fn create_cuda_runtime_bindings ( sdk : & cuda_sdk:: CudaSdk , outdir : & path:: Path ) {
189+ fn create_runtime_bindings ( sdk : & cuda_sdk:: CudaSdk , outdir : & path:: Path ) {
149190 if !cfg ! ( feature = "runtime" ) {
150191 return ;
151192 }
@@ -165,14 +206,13 @@ fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
165206 . iter ( )
166207 . map ( |p| format ! ( "-I{}" , p. display( ) ) ) ,
167208 )
168- . allowlist_type ( "^CU.*" )
169- . allowlist_type ( "^cuda.*" )
170- . allowlist_type ( "^libraryPropertyType.*" )
171- . allowlist_var ( "^CU.*" )
172- . allowlist_function ( "^cu.*" )
209+ . allowlist_file ( r".*cuda[^/\\]*\.h" )
210+ . allowlist_file ( r".*cuComplex\.h" )
211+ . allowlist_recursively ( false )
173212 . default_enum_style ( bindgen:: EnumVariation :: Rust {
174213 non_exhaustive : false ,
175214 } )
215+ . disable_nested_struct_naming ( )
176216 . derive_default ( true )
177217 . derive_eq ( true )
178218 . derive_hash ( true )
@@ -188,19 +228,51 @@ fn create_cuda_runtime_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
188228}
189229
190230fn create_cublas_bindings ( sdk : & cuda_sdk:: CudaSdk , outdir : & path:: Path ) {
191- #[ rustfmt:: skip]
192231 let params = & [
193- ( cfg ! ( feature = "cublas" ) , "cublas" , "^cublas.*" , "^CUBLAS.*" ) ,
194- ( cfg ! ( feature = "cublaslt" ) , "cublasLt" , "^cublasLt.*" , "^CUBLASLT.*" ) ,
195- ( cfg ! ( feature = "cublasxt" ) , "cublasXt" , "^cublasXt.*" , "^CUBLASXT.*" ) ,
232+ (
233+ cfg ! ( feature = "cublas" ) ,
234+ "cublas" ,
235+ vec ! [ r".*cublas(_api|_v2)\.h" ] ,
236+ vec ! [
237+ r".*cuComplex\.h" ,
238+ r".*driver_types\.h" ,
239+ r".*library_types\.h" ,
240+ r".*vector_types\.h" ,
241+ ] ,
242+ ) ,
243+ (
244+ cfg ! ( feature = "cublasLt" ) ,
245+ "cublasLt" ,
246+ vec ! [ r".*cublasLt\.h" ] ,
247+ vec ! [
248+ r".*cublas(_api|_v2)*\.h" ,
249+ r".*cuComplex\.h" ,
250+ r".*driver_types\.h" ,
251+ r".*library_types\.h" ,
252+ r".*vector_types\.h" ,
253+ r".*std\w+\.h" ,
254+ ] ,
255+ ) ,
256+ (
257+ cfg ! ( feature = "cublasXt" ) ,
258+ "cublasXt" ,
259+ vec ! [ r".*cublasXt\.h" ] ,
260+ vec ! [
261+ r".*cublas(_api|_v2)*\.h" ,
262+ r".*cuComplex\.h" ,
263+ r".*driver_types\.h" ,
264+ r".*library_types\.h" ,
265+ r".*vector_types\.h" ,
266+ ] ,
267+ ) ,
196268 ] ;
197- for ( should_generate, pkg, tf , var ) in params {
269+ for ( should_generate, pkg, allowed , blocked ) in params {
198270 if !should_generate {
199271 continue ;
200272 }
201273 let bindgen_path = path:: PathBuf :: from ( format ! ( "{}/{pkg}_sys.rs" , outdir. display( ) ) ) ;
202274 let header = format ! ( "build/{pkg}_wrapper.h" ) ;
203- let bindings = bindgen:: Builder :: default ( )
275+ let mut bindings = bindgen:: Builder :: default ( )
204276 . header ( & header)
205277 . parse_callbacks ( Box :: new ( callbacks:: FunctionRenames :: new (
206278 pkg,
@@ -214,9 +286,16 @@ fn create_cublas_bindings(sdk: &cuda_sdk::CudaSdk, outdir: &path::Path) {
214286 . iter ( )
215287 . map ( |p| format ! ( "-I{}" , p. display( ) ) ) ,
216288 )
217- . allowlist_type ( tf)
218- . allowlist_function ( tf)
219- . allowlist_var ( var)
289+ . allowlist_recursively ( false ) ;
290+
291+ for file in allowed {
292+ bindings = bindings. allowlist_file ( file) ;
293+ }
294+ for file in blocked {
295+ bindings = bindings. blocklist_file ( file) ;
296+ }
297+
298+ let bindings = bindings
220299 . default_enum_style ( bindgen:: EnumVariation :: Rust {
221300 non_exhaustive : false ,
222301 } )
0 commit comments