@@ -97,10 +97,8 @@ pub(crate) use self::Enzyme_AD::*;
9797#[ cfg( feature = "llvm_enzyme" ) ]
9898pub ( crate ) mod Enzyme_AD {
9999 use std:: ffi:: { c_char, c_void} ;
100- use std:: sync:: { Mutex , OnceLock } ;
100+ use std:: sync:: { Mutex , MutexGuard , OnceLock } ;
101101
102- use rustc_codegen_ssa:: back:: write:: CodegenContext ;
103- use rustc_codegen_ssa:: traits:: WriteBackendMethods ;
104102 use rustc_middle:: bug;
105103 use rustc_session:: config:: { Sysroot , host_tuple} ;
106104 use rustc_session:: filesearch;
@@ -202,18 +200,29 @@ pub(crate) mod Enzyme_AD {
202200 static ENZYME_INSTANCE : OnceLock < Mutex < EnzymeWrapper > > = OnceLock :: new ( ) ;
203201
204202 impl EnzymeWrapper {
205- pub ( crate ) fn init < ' a , B : WriteBackendMethods > (
206- cgcx : & ' a CodegenContext < B > ,
207- ) -> & ' static Mutex < Self > {
208- ENZYME_INSTANCE . get_or_init ( || {
209- Self :: call_dynamic ( cgcx)
210- . unwrap_or_else ( |e| bug ! ( "failed to load Enzyme: {e}" ) )
211- . into ( )
212- } )
203+ /// Initialize EnzymeWrapper with the given sysroot if not already initialized.
204+ /// Safe to call multiple times - subsequent calls are no-ops due to OnceLock.
205+ pub ( crate ) fn get_or_init (
206+ sysroot : & rustc_session:: config:: Sysroot ,
207+ ) -> MutexGuard < ' static , Self > {
208+ ENZYME_INSTANCE
209+ . get_or_init ( || {
210+ Self :: call_dynamic ( sysroot)
211+ . unwrap_or_else ( |e| bug ! ( "failed to load Enzyme: {e}" ) )
212+ . into ( )
213+ } )
214+ . lock ( )
215+ . unwrap ( )
213216 }
214217
215- pub ( crate ) fn get_instance ( ) -> & ' static Mutex < Self > {
216- ENZYME_INSTANCE . get ( ) . expect ( "EnzymeWrapper not initialized" )
218+ /// Get the EnzymeWrapper instance. Panics if not initialized.
219+ /// Call get_or_init with a sysroot first.
220+ pub ( crate ) fn get_instance ( ) -> MutexGuard < ' static , Self > {
221+ ENZYME_INSTANCE
222+ . get ( )
223+ . expect ( "EnzymeWrapper not initialized. Call get_or_init with sysroot first." )
224+ . lock ( )
225+ . unwrap ( )
217226 }
218227
219228 pub ( crate ) fn new_type_tree ( & self ) -> CTypeTreeRef {
@@ -350,10 +359,10 @@ pub(crate) mod Enzyme_AD {
350359 }
351360
352361 #[ allow( non_snake_case) ]
353- fn call_dynamic < ' a , B : WriteBackendMethods > (
354- cgcx : & ' a CodegenContext < B > ,
362+ fn call_dynamic (
363+ sysroot : & rustc_session :: config :: Sysroot ,
355364 ) -> Result < Self , Box < dyn std:: error:: Error > > {
356- let enzyme_path = Self :: get_enzyme_path ( & cgcx . sysroot ) ?;
365+ let enzyme_path = Self :: get_enzyme_path ( sysroot) ?;
357366 let lib = unsafe { libloading:: Library :: new ( enzyme_path) ? } ;
358367
359368 load_ptrs_by_symbols_fn ! (
@@ -587,19 +596,19 @@ pub(crate) mod Fallback_AD {
587596impl TypeTree {
588597 pub ( crate ) fn new ( ) -> TypeTree {
589598 let wrapper = EnzymeWrapper :: get_instance ( ) ;
590- let inner = wrapper. lock ( ) . unwrap ( ) . new_type_tree ( ) ;
599+ let inner = wrapper. new_type_tree ( ) ;
591600 TypeTree { inner }
592601 }
593602
594603 pub ( crate ) fn from_type ( t : CConcreteType , ctx : & Context ) -> TypeTree {
595604 let wrapper = EnzymeWrapper :: get_instance ( ) ;
596- let inner = wrapper. lock ( ) . unwrap ( ) . new_type_tree_ct ( t, ctx) ;
605+ let inner = wrapper. new_type_tree_ct ( t, ctx) ;
597606 TypeTree { inner }
598607 }
599608
600609 pub ( crate ) fn merge ( self , other : Self ) -> Self {
601610 let wrapper = EnzymeWrapper :: get_instance ( ) ;
602- wrapper. lock ( ) . unwrap ( ) . merge_type_tree ( self . inner , other. inner ) ;
611+ wrapper. merge_type_tree ( self . inner , other. inner ) ;
603612 drop ( other) ;
604613 self
605614 }
@@ -614,7 +623,7 @@ impl TypeTree {
614623 ) -> Self {
615624 let layout = std:: ffi:: CString :: new ( layout) . unwrap ( ) ;
616625 let wrapper = EnzymeWrapper :: get_instance ( ) ;
617- wrapper. lock ( ) . unwrap ( ) . shift_indicies_eq (
626+ wrapper. shift_indicies_eq (
618627 self . inner ,
619628 layout. as_ptr ( ) ,
620629 offset as i64 ,
@@ -627,36 +636,30 @@ impl TypeTree {
627636
628637 pub ( crate ) fn insert ( & mut self , indices : & [ i64 ] , ct : CConcreteType , ctx : & Context ) {
629638 let wrapper = EnzymeWrapper :: get_instance ( ) ;
630- wrapper. lock ( ) . unwrap ( ) . tree_insert_eq (
631- self . inner ,
632- indices. as_ptr ( ) ,
633- indices. len ( ) ,
634- ct,
635- ctx,
636- ) ;
639+ wrapper. tree_insert_eq ( self . inner , indices. as_ptr ( ) , indices. len ( ) , ct, ctx) ;
637640 }
638641}
639642
640643impl Clone for TypeTree {
641644 fn clone ( & self ) -> Self {
642645 let wrapper = EnzymeWrapper :: get_instance ( ) ;
643- let inner = wrapper. lock ( ) . unwrap ( ) . new_type_tree_tr ( self . inner ) ;
646+ let inner = wrapper. new_type_tree_tr ( self . inner ) ;
644647 TypeTree { inner }
645648 }
646649}
647650
648651impl std:: fmt:: Display for TypeTree {
649652 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
650653 let wrapper = EnzymeWrapper :: get_instance ( ) ;
651- let ptr = wrapper. lock ( ) . unwrap ( ) . tree_to_string ( self . inner ) ;
654+ let ptr = wrapper. tree_to_string ( self . inner ) ;
652655 let cstr = unsafe { std:: ffi:: CStr :: from_ptr ( ptr) } ;
653656 match cstr. to_str ( ) {
654657 Ok ( x) => write ! ( f, "{}" , x) ?,
655658 Err ( err) => write ! ( f, "could not parse: {}" , err) ?,
656659 }
657660
658661 // delete C string pointer
659- wrapper. lock ( ) . unwrap ( ) . tree_to_string_free ( ptr) ;
662+ wrapper. tree_to_string_free ( ptr) ;
660663
661664 Ok ( ( ) )
662665 }
@@ -671,6 +674,6 @@ impl std::fmt::Debug for TypeTree {
671674impl Drop for TypeTree {
672675 fn drop ( & mut self ) {
673676 let wrapper = EnzymeWrapper :: get_instance ( ) ;
674- wrapper. lock ( ) . unwrap ( ) . free_type_tree ( self . inner )
677+ wrapper. free_type_tree ( self . inner )
675678 }
676679}
0 commit comments