@@ -4,8 +4,8 @@ use std::os::raw::c_char;
44use std:: ptr;
55
66use cust:: stream:: Stream ;
7- use cust_raw:: cublas_sys ;
8- use cust_raw:: driver_sys ;
7+ use cust_raw:: cublas ;
8+ use cust_raw:: driver ;
99
1010use super :: error:: DropResult ;
1111use super :: error:: ToResult as _;
@@ -73,7 +73,7 @@ bitflags::bitflags! {
7373/// - [Matrix Multiplication <span style="float:right;">`gemm`</span>](CublasContext::gemm)
7474#[ derive( Debug ) ]
7575pub struct CublasContext {
76- pub ( crate ) raw : cublas_sys :: cublasHandle_t ,
76+ pub ( crate ) raw : cublas :: cublasHandle_t ,
7777}
7878
7979impl CublasContext {
@@ -92,10 +92,10 @@ impl CublasContext {
9292 pub fn new ( ) -> Result < Self > {
9393 let mut raw = MaybeUninit :: uninit ( ) ;
9494 unsafe {
95- cublas_sys :: cublasCreate ( raw. as_mut_ptr ( ) ) . to_result ( ) ?;
96- cublas_sys :: cublasSetPointerMode (
95+ cublas :: cublasCreate ( raw. as_mut_ptr ( ) ) . to_result ( ) ?;
96+ cublas :: cublasSetPointerMode (
9797 raw. assume_init ( ) ,
98- cublas_sys :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE ,
98+ cublas :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE ,
9999 )
100100 . to_result ( ) ?;
101101 Ok ( Self {
@@ -112,7 +112,7 @@ impl CublasContext {
112112
113113 unsafe {
114114 let inner = mem:: replace ( & mut ctx. raw , ptr:: null_mut ( ) ) ;
115- match cublas_sys :: cublasDestroy ( inner) . to_result ( ) {
115+ match cublas :: cublasDestroy ( inner) . to_result ( ) {
116116 Ok ( ( ) ) => {
117117 mem:: forget ( ctx) ;
118118 Ok ( ( ) )
@@ -127,7 +127,7 @@ impl CublasContext {
127127 let mut raw = MaybeUninit :: < u32 > :: uninit ( ) ;
128128 unsafe {
129129 // getVersion can't fail
130- cublas_sys :: cublasGetVersion ( self . raw , raw. as_mut_ptr ( ) . cast ( ) )
130+ cublas :: cublasGetVersion ( self . raw , raw. as_mut_ptr ( ) . cast ( ) )
131131 . to_result ( )
132132 . unwrap ( ) ;
133133
@@ -145,17 +145,15 @@ impl CublasContext {
145145 ) -> Result < T > {
146146 unsafe {
147147 // cudaStream_t is the same as CUstream
148- cublas_sys :: cublasSetStream (
148+ cublas :: cublasSetStream (
149149 self . raw ,
150- mem:: transmute :: < * mut driver_sys:: CUstream_st , * mut cublas_sys:: CUstream_st > (
151- stream. as_inner ( ) ,
152- ) ,
150+ mem:: transmute :: < driver:: CUstream , cublas:: cudaStream_t > ( stream. as_inner ( ) ) ,
153151 )
154152 . to_result ( ) ?;
155153 let res = func ( self ) ?;
156154 // reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to
157155 // execute a raw sys function with the context's handle.
158- cublas_sys :: cublasSetStream ( self . raw , ptr:: null_mut ( ) ) . to_result ( ) ?;
156+ cublas :: cublasSetStream ( self . raw , ptr:: null_mut ( ) ) . to_result ( ) ?;
159157 Ok ( res)
160158 }
161159 }
@@ -185,12 +183,12 @@ impl CublasContext {
185183 /// ```
186184 pub fn set_atomics_mode ( & self , allowed : bool ) -> Result < ( ) > {
187185 unsafe {
188- Ok ( cublas_sys :: cublasSetAtomicsMode (
186+ Ok ( cublas :: cublasSetAtomicsMode (
189187 self . raw ,
190188 if allowed {
191- cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED
189+ cublas :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED
192190 } else {
193- cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED
191+ cublas :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED
194192 } ,
195193 )
196194 . to_result ( ) ?)
@@ -215,10 +213,10 @@ impl CublasContext {
215213 pub fn get_atomics_mode ( & self ) -> Result < bool > {
216214 let mut mode = MaybeUninit :: uninit ( ) ;
217215 unsafe {
218- cublas_sys :: cublasGetAtomicsMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
216+ cublas :: cublasGetAtomicsMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
219217 Ok ( match mode. assume_init ( ) {
220- cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED => true ,
221- cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED => false ,
218+ cublas :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED => true ,
219+ cublas :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED => false ,
222220 } )
223221 }
224222 }
@@ -238,9 +236,9 @@ impl CublasContext {
238236 /// ```
239237 pub fn set_math_mode ( & self , math_mode : MathMode ) -> Result < ( ) > {
240238 unsafe {
241- Ok ( cublas_sys :: cublasSetMathMode (
239+ Ok ( cublas :: cublasSetMathMode (
242240 self . raw ,
243- mem:: transmute :: < u32 , cublas_sys :: cublasMath_t > ( math_mode. bits ( ) ) ,
241+ mem:: transmute :: < u32 , cublas :: cublasMath_t > ( math_mode. bits ( ) ) ,
244242 )
245243 . to_result ( ) ?)
246244 }
@@ -263,7 +261,7 @@ impl CublasContext {
263261 pub fn get_math_mode ( & self ) -> Result < MathMode > {
264262 let mut mode = MaybeUninit :: uninit ( ) ;
265263 unsafe {
266- cublas_sys :: cublasGetMathMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
264+ cublas :: cublasGetMathMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
267265 Ok ( MathMode :: from_bits ( mode. assume_init ( ) as u32 )
268266 . expect ( "Invalid MathMode from cuBLAS" ) )
269267 }
@@ -303,7 +301,7 @@ impl CublasContext {
303301 let path = log_file_name. map ( |p| CString :: new ( p) . expect ( "nul in log_file_name" ) ) ;
304302 let path_ptr = path. map_or ( ptr:: null ( ) , |s| s. as_ptr ( ) ) ;
305303
306- cublas_sys :: cublasLoggerConfigure (
304+ cublas :: cublasLoggerConfigure (
307305 enable as i32 ,
308306 log_to_stdout as i32 ,
309307 log_to_stderr as i32 ,
@@ -320,7 +318,7 @@ impl CublasContext {
320318 ///
321319 /// The callback must not panic and unwind.
322320 pub unsafe fn set_logger_callback ( callback : Option < unsafe extern "C" fn ( * const c_char ) > ) {
323- cublas_sys :: cublasSetLoggerCallback ( callback)
321+ cublas :: cublasSetLoggerCallback ( callback)
324322 . to_result ( )
325323 . unwrap ( ) ;
326324 }
@@ -329,7 +327,7 @@ impl CublasContext {
329327 pub fn get_logger_callback ( ) -> Option < unsafe extern "C" fn ( * const c_char ) > {
330328 let mut cb = MaybeUninit :: uninit ( ) ;
331329 unsafe {
332- cublas_sys :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) )
330+ cublas :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) )
333331 . to_result ( )
334332 . unwrap ( ) ;
335333 cb. assume_init ( )
@@ -340,7 +338,7 @@ impl CublasContext {
340338impl Drop for CublasContext {
341339 fn drop ( & mut self ) {
342340 unsafe {
343- let _ = cublas_sys :: cublasDestroy ( self . raw ) ;
341+ let _ = cublas :: cublasDestroy ( self . raw ) ;
344342 }
345343 }
346344}
0 commit comments