Skip to content

Commit 52f0d2e

Browse files
committed
Make EnzymeWrapper::get_instance return a MutexGuard; EnzymeWrapper::get_or_init now takes only the sysroot instead of a whole session or context
1 parent 5b6aa75 commit 52f0d2e

File tree

8 files changed

+77
-52
lines changed

8 files changed

+77
-52
lines changed

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -529,33 +529,34 @@ fn thin_lto(
529529
}
530530

531531
fn enable_autodiff_settings(cgcx: &CodegenContext<LlvmCodegenBackend>, ad: &[config::AutoDiff]) {
532-
use std::sync::Mutex;
533-
let enzyme: &'static Mutex<llvm::EnzymeWrapper> = llvm::EnzymeWrapper::init(cgcx);
532+
// Initialize Enzyme if not already done (idempotent due to OnceLock)
533+
// This ensures it works even if LlvmCodegenBackend::init() didn't run it
534+
let mut enzyme = llvm::EnzymeWrapper::get_or_init(&cgcx.sysroot);
534535

535536
for val in ad {
536537
// We intentionally don't use a wildcard, to not forget handling anything new.
537538
match val {
538539
config::AutoDiff::PrintPerf => {
539-
enzyme.lock().unwrap().set_print_perf(true);
540+
enzyme.set_print_perf(true);
540541
}
541542
config::AutoDiff::PrintAA => {
542-
enzyme.lock().unwrap().set_print_activity(true);
543+
enzyme.set_print_activity(true);
543544
}
544545
config::AutoDiff::PrintTA => {
545-
enzyme.lock().unwrap().set_print_type(true);
546+
enzyme.set_print_type(true);
546547
}
547548
config::AutoDiff::PrintTAFn(fun) => {
548-
enzyme.lock().unwrap().set_print_type(true); // Enable general type printing
549-
enzyme.lock().unwrap().set_print_type_fun(&fun); // Set specific function to analyze
549+
enzyme.set_print_type(true); // Enable general type printing
550+
enzyme.set_print_type_fun(&fun); // Set specific function to analyze
550551
}
551552
config::AutoDiff::Inline => {
552-
enzyme.lock().unwrap().set_inline(true);
553+
enzyme.set_inline(true);
553554
}
554555
config::AutoDiff::LooseTypes => {
555-
enzyme.lock().unwrap().set_loose_types(true);
556+
enzyme.set_loose_types(true);
556557
}
557558
config::AutoDiff::PrintSteps => {
558-
enzyme.lock().unwrap().set_print(true);
559+
enzyme.set_print(true);
559560
}
560561
// We handle this in the PassWrapper.cpp
561562
config::AutoDiff::PrintPasses => {}
@@ -574,9 +575,9 @@ fn enable_autodiff_settings(cgcx: &CodegenContext<LlvmCodegenBackend>, ad: &[con
574575
}
575576
}
576577
// This helps with handling enums for now.
577-
enzyme.lock().unwrap().set_strict_aliasing(false);
578+
enzyme.set_strict_aliasing(false);
578579
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
579-
enzyme.lock().unwrap().set_rust_rules(true);
580+
enzyme.set_rust_rules(true);
580581
}
581582

582583
pub(crate) fn run_pass_manager(

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::ffi::{CStr, CString};
22
use std::io::{self, Write};
33
use std::path::{Path, PathBuf};
44
use std::ptr::null_mut;
5-
use std::sync::{Arc, Mutex};
5+
use std::sync::Arc;
66
use std::{fs, slice, str};
77

88
use libc::{c_char, c_int, c_void, size_t};
@@ -727,8 +727,9 @@ pub(crate) unsafe fn llvm_optimize(
727727
let llvm_plugins = config.llvm_plugins.join(",");
728728

729729
let enzyme_fn = if consider_ad {
730-
let wrapper: &'static Mutex<llvm::EnzymeWrapper> = llvm::EnzymeWrapper::init(cgcx);
731-
wrapper.lock().unwrap().registerEnzymeAndPassPipeline
730+
// Initialize Enzyme if not already done (idempotent due to OnceLock)
731+
let wrapper = llvm::EnzymeWrapper::get_or_init(&cgcx.sysroot);
732+
wrapper.registerEnzymeAndPassPipeline
732733
} else {
733734
std::ptr::null()
734735
};

compiler/rustc_codegen_llvm/src/builder.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1099,7 +1099,13 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
10991099
// vs. copying a struct with mixed types requires different derivative handling.
11001100
// The TypeTree tells Enzyme exactly what memory layout to expect.
11011101
if let Some(tt) = tt {
1102-
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
1102+
crate::typetree::add_tt(
1103+
self.cx().llmod,
1104+
self.cx().llcx,
1105+
memcpy,
1106+
tt,
1107+
&self.cx().tcx.sess.opts.sysroot,
1108+
);
11031109
}
11041110
}
11051111

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
375375
);
376376

377377
if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
378-
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
378+
crate::typetree::add_tt(
379+
cx.llmod,
380+
cx.llcx,
381+
fn_to_diff,
382+
fnc_tree,
383+
&builder.cx().tcx.sess.opts.sysroot,
384+
);
379385
}
380386

381387
let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);

compiler/rustc_codegen_llvm/src/lib.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId};
4141
use rustc_middle::ty::TyCtxt;
4242
use rustc_middle::util::Providers;
4343
use rustc_session::Session;
44-
use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest};
44+
use rustc_session::config::{AutoDiff, OptLevel, OutputFilenames, PrintKind, PrintRequest};
4545
use rustc_span::Symbol;
4646
use rustc_target::spec::{RelocModel, TlsModel};
4747

@@ -240,6 +240,13 @@ impl CodegenBackend for LlvmCodegenBackend {
240240

241241
fn init(&self, sess: &Session) {
242242
llvm_util::init(sess); // Make sure llvm is inited
243+
244+
if sess.opts.unstable_opts.autodiff.contains(&AutoDiff::Enable) {
245+
#[cfg(feature = "llvm_enzyme")]
246+
{
247+
drop(llvm::EnzymeWrapper::get_or_init(&sess.opts.sysroot));
248+
}
249+
}
243250
}
244251

245252
fn provide(&self, providers: &mut Providers) {

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 34 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,8 @@ pub(crate) use self::Enzyme_AD::*;
9797
#[cfg(feature = "llvm_enzyme")]
9898
pub(crate) mod Enzyme_AD {
9999
use std::ffi::{c_char, c_void};
100-
use std::sync::{Mutex, OnceLock};
100+
use std::sync::{Mutex, MutexGuard, OnceLock};
101101

102-
use rustc_codegen_ssa::back::write::CodegenContext;
103-
use rustc_codegen_ssa::traits::WriteBackendMethods;
104102
use rustc_middle::bug;
105103
use rustc_session::config::{Sysroot, host_tuple};
106104
use rustc_session::filesearch;
@@ -202,18 +200,29 @@ pub(crate) mod Enzyme_AD {
202200
static ENZYME_INSTANCE: OnceLock<Mutex<EnzymeWrapper>> = OnceLock::new();
203201

204202
impl EnzymeWrapper {
205-
pub(crate) fn init<'a, B: WriteBackendMethods>(
206-
cgcx: &'a CodegenContext<B>,
207-
) -> &'static Mutex<Self> {
208-
ENZYME_INSTANCE.get_or_init(|| {
209-
Self::call_dynamic(cgcx)
210-
.unwrap_or_else(|e| bug!("failed to load Enzyme: {e}"))
211-
.into()
212-
})
203+
/// Initialize EnzymeWrapper with the given sysroot if not already initialized.
204+
/// Safe to call multiple times: the OnceLock runs initialization only once, but every call still locks the mutex, so drop the returned guard before calling again.
205+
pub(crate) fn get_or_init(
206+
sysroot: &rustc_session::config::Sysroot,
207+
) -> MutexGuard<'static, Self> {
208+
ENZYME_INSTANCE
209+
.get_or_init(|| {
210+
Self::call_dynamic(sysroot)
211+
.unwrap_or_else(|e| bug!("failed to load Enzyme: {e}"))
212+
.into()
213+
})
214+
.lock()
215+
.unwrap()
213216
}
214217

215-
pub(crate) fn get_instance() -> &'static Mutex<Self> {
216-
ENZYME_INSTANCE.get().expect("EnzymeWrapper not initialized")
218+
/// Get the EnzymeWrapper instance. Panics if not initialized.
219+
/// Call get_or_init with a sysroot first.
220+
pub(crate) fn get_instance() -> MutexGuard<'static, Self> {
221+
ENZYME_INSTANCE
222+
.get()
223+
.expect("EnzymeWrapper not initialized. Call get_or_init with sysroot first.")
224+
.lock()
225+
.unwrap()
217226
}
218227

219228
pub(crate) fn new_type_tree(&self) -> CTypeTreeRef {
@@ -350,10 +359,10 @@ pub(crate) mod Enzyme_AD {
350359
}
351360

352361
#[allow(non_snake_case)]
353-
fn call_dynamic<'a, B: WriteBackendMethods>(
354-
cgcx: &'a CodegenContext<B>,
362+
fn call_dynamic(
363+
sysroot: &rustc_session::config::Sysroot,
355364
) -> Result<Self, Box<dyn std::error::Error>> {
356-
let enzyme_path = Self::get_enzyme_path(&cgcx.sysroot)?;
365+
let enzyme_path = Self::get_enzyme_path(sysroot)?;
357366
let lib = unsafe { libloading::Library::new(enzyme_path)? };
358367

359368
load_ptrs_by_symbols_fn!(
@@ -587,19 +596,19 @@ pub(crate) mod Fallback_AD {
587596
impl TypeTree {
588597
pub(crate) fn new() -> TypeTree {
589598
let wrapper = EnzymeWrapper::get_instance();
590-
let inner = wrapper.lock().unwrap().new_type_tree();
599+
let inner = wrapper.new_type_tree();
591600
TypeTree { inner }
592601
}
593602

594603
pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
595604
let wrapper = EnzymeWrapper::get_instance();
596-
let inner = wrapper.lock().unwrap().new_type_tree_ct(t, ctx);
605+
let inner = wrapper.new_type_tree_ct(t, ctx);
597606
TypeTree { inner }
598607
}
599608

600609
pub(crate) fn merge(self, other: Self) -> Self {
601610
let wrapper = EnzymeWrapper::get_instance();
602-
wrapper.lock().unwrap().merge_type_tree(self.inner, other.inner);
611+
wrapper.merge_type_tree(self.inner, other.inner);
603612
drop(other);
604613
self
605614
}
@@ -614,7 +623,7 @@ impl TypeTree {
614623
) -> Self {
615624
let layout = std::ffi::CString::new(layout).unwrap();
616625
let wrapper = EnzymeWrapper::get_instance();
617-
wrapper.lock().unwrap().shift_indicies_eq(
626+
wrapper.shift_indicies_eq(
618627
self.inner,
619628
layout.as_ptr(),
620629
offset as i64,
@@ -627,36 +636,30 @@ impl TypeTree {
627636

628637
pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
629638
let wrapper = EnzymeWrapper::get_instance();
630-
wrapper.lock().unwrap().tree_insert_eq(
631-
self.inner,
632-
indices.as_ptr(),
633-
indices.len(),
634-
ct,
635-
ctx,
636-
);
639+
wrapper.tree_insert_eq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
637640
}
638641
}
639642

640643
impl Clone for TypeTree {
641644
fn clone(&self) -> Self {
642645
let wrapper = EnzymeWrapper::get_instance();
643-
let inner = wrapper.lock().unwrap().new_type_tree_tr(self.inner);
646+
let inner = wrapper.new_type_tree_tr(self.inner);
644647
TypeTree { inner }
645648
}
646649
}
647650

648651
impl std::fmt::Display for TypeTree {
649652
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
650653
let wrapper = EnzymeWrapper::get_instance();
651-
let ptr = wrapper.lock().unwrap().tree_to_string(self.inner);
654+
let ptr = wrapper.tree_to_string(self.inner);
652655
let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
653656
match cstr.to_str() {
654657
Ok(x) => write!(f, "{}", x)?,
655658
Err(err) => write!(f, "could not parse: {}", err)?,
656659
}
657660

658661
// delete C string pointer
659-
wrapper.lock().unwrap().tree_to_string_free(ptr);
662+
wrapper.tree_to_string_free(ptr);
660663

661664
Ok(())
662665
}
@@ -671,6 +674,6 @@ impl std::fmt::Debug for TypeTree {
671674
impl Drop for TypeTree {
672675
fn drop(&mut self) {
673676
let wrapper = EnzymeWrapper::get_instance();
674-
wrapper.lock().unwrap().free_type_tree(self.inner)
677+
wrapper.free_type_tree(self.inner)
675678
}
676679
}

compiler/rustc_codegen_llvm/src/typetree.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ pub(crate) fn add_tt<'ll>(
6363
llcx: &'ll llvm::Context,
6464
fn_def: &'ll Value,
6565
tt: FncTree,
66+
sysroot: &rustc_session::config::Sysroot,
6667
) {
6768
let inputs = tt.args;
6869
let ret_tt: RustTypeTree = tt.ret;
@@ -75,7 +76,7 @@ pub(crate) fn add_tt<'ll>(
7576
let attr_name = "enzyme_type";
7677
let c_attr_name = CString::new(attr_name).unwrap();
7778

78-
let enzyme_wrapper = EnzymeWrapper::get_instance().lock().unwrap();
79+
let enzyme_wrapper = EnzymeWrapper::get_or_init(sysroot);
7980

8081
for (i, input) in inputs.iter().enumerate() {
8182
unsafe {
@@ -120,6 +121,7 @@ pub(crate) fn add_tt<'ll>(
120121
_llcx: &'ll llvm::Context,
121122
_fn_def: &'ll Value,
122123
_tt: FncTree,
124+
_sysroot: &rustc_session::config::Sysroot,
123125
) {
124126
unimplemented!()
125127
}

compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -888,7 +888,7 @@ extern "C" LLVMRustResult LLVMRustOptimize(
888888
}
889889

890890
// now load "-enzyme" pass:
891-
#ifdef ENZYME
891+
// With dlopen, the ENZYME macro may not be defined, so check EnzymePtr at runtime instead
892892
if (EnzymePtr) {
893893

894894
if (PrintBeforeEnzyme) {
@@ -910,7 +910,6 @@ extern "C" LLVMRustResult LLVMRustOptimize(
910910
MPM.addPass(PrintModulePass(outs(), Banner, true, false));
911911
}
912912
}
913-
#endif
914913
if (PrintPasses) {
915914
// Print all passes from the PM:
916915
std::string Pipeline;

0 commit comments

Comments
 (0)