Skip to content

Commit 3c31d3d

Browse files
committed
feat: dlopen Enzyme
1 parent a591113 commit 3c31d3d

File tree

13 files changed

+539
-242
lines changed

13 files changed

+539
-242
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3574,6 +3574,7 @@ dependencies = [
35743574
"gimli 0.31.1",
35753575
"itertools",
35763576
"libc",
3577+
"libloading 0.9.0",
35773578
"measureme",
35783579
"object 0.37.3",
35793580
"rustc-demangle",

compiler/rustc_codegen_llvm/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ bitflags = "2.4.1"
1414
gimli = "0.31"
1515
itertools = "0.12"
1616
libc = "0.2"
17+
libloading = "0.9.0"
1718
measureme = "12.0.1"
1819
object = { version = "0.37.0", default-features = false, features = ["std", "read"] }
1920
rustc-demangle = "0.1.21"

compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -528,31 +528,37 @@ fn thin_lto(
528528
}
529529
}
530530

531-
fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
531+
#[cfg(feature = "llvm_enzyme")]
532+
pub(crate) fn enable_autodiff_settings(
533+
sysroot: &rustc_session::config::Sysroot,
534+
ad: &[config::AutoDiff],
535+
) {
536+
let mut enzyme = llvm::EnzymeWrapper::get_or_init(sysroot);
537+
532538
for val in ad {
533539
// We intentionally don't use a wildcard, to not forget handling anything new.
534540
match val {
535541
config::AutoDiff::PrintPerf => {
536-
llvm::set_print_perf(true);
542+
enzyme.set_print_perf(true);
537543
}
538544
config::AutoDiff::PrintAA => {
539-
llvm::set_print_activity(true);
545+
enzyme.set_print_activity(true);
540546
}
541547
config::AutoDiff::PrintTA => {
542-
llvm::set_print_type(true);
548+
enzyme.set_print_type(true);
543549
}
544550
config::AutoDiff::PrintTAFn(fun) => {
545-
llvm::set_print_type(true); // Enable general type printing
546-
llvm::set_print_type_fun(&fun); // Set specific function to analyze
551+
enzyme.set_print_type(true); // Enable general type printing
552+
enzyme.set_print_type_fun(&fun); // Set specific function to analyze
547553
}
548554
config::AutoDiff::Inline => {
549-
llvm::set_inline(true);
555+
enzyme.set_inline(true);
550556
}
551557
config::AutoDiff::LooseTypes => {
552-
llvm::set_loose_types(true);
558+
enzyme.set_loose_types(true);
553559
}
554560
config::AutoDiff::PrintSteps => {
555-
llvm::set_print(true);
561+
enzyme.set_print(true);
556562
}
557563
// We handle this in the PassWrapper.cpp
558564
config::AutoDiff::PrintPasses => {}
@@ -571,9 +577,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
571577
}
572578
}
573579
// This helps with handling enums for now.
574-
llvm::set_strict_aliasing(false);
580+
enzyme.set_strict_aliasing(false);
575581
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
576-
llvm::set_rust_rules(true);
582+
enzyme.set_rust_rules(true);
577583
}
578584

579585
pub(crate) fn run_pass_manager(
@@ -608,10 +614,6 @@ pub(crate) fn run_pass_manager(
608614
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD }
609615
};
610616

611-
if enable_ad {
612-
enable_autodiff_settings(&config.autodiff);
613-
}
614-
615617
unsafe {
616618
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage);
617619
}

compiler/rustc_codegen_llvm/src/back/write.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -726,6 +726,13 @@ pub(crate) unsafe fn llvm_optimize(
726726

727727
let llvm_plugins = config.llvm_plugins.join(",");
728728

729+
let enzyme_fn = if consider_ad {
730+
let wrapper = llvm::EnzymeWrapper::get_or_init(&cgcx.sysroot);
731+
wrapper.registerEnzymeAndPassPipeline
732+
} else {
733+
std::ptr::null()
734+
};
735+
729736
let result = unsafe {
730737
llvm::LLVMRustOptimize(
731738
module.module_llvm.llmod(),
@@ -745,7 +752,7 @@ pub(crate) unsafe fn llvm_optimize(
745752
vectorize_loop,
746753
config.no_builtins,
747754
config.emit_lifetime_markers,
748-
run_enzyme,
755+
enzyme_fn,
749756
print_before_enzyme,
750757
print_after_enzyme,
751758
print_passes,

compiler/rustc_codegen_llvm/src/lib.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,18 @@ impl CodegenBackend for LlvmCodegenBackend {
240240

241241
fn init(&self, sess: &Session) {
242242
llvm_util::init(sess); // Make sure llvm is inited
243+
244+
#[cfg(feature = "llvm_enzyme")]
245+
{
246+
use rustc_session::config::AutoDiff;
247+
if sess.opts.unstable_opts.autodiff.contains(&AutoDiff::Enable) {
248+
{
249+
use crate::back::lto::enable_autodiff_settings;
250+
251+
enable_autodiff_settings(&sess.opts.sysroot, &sess.opts.unstable_opts.autodiff);
252+
}
253+
}
254+
}
243255
}
244256

245257
fn provide(&self, providers: &mut Providers) {

0 commit comments

Comments
 (0)