Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -3574,6 +3574,7 @@ dependencies = [
"gimli 0.31.1",
"itertools",
"libc",
"libloading 0.9.0",
"measureme",
"object 0.37.3",
"rustc-demangle",
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ bitflags = "2.4.1"
gimli = "0.31"
itertools = "0.12"
libc = "0.2"
libloading = "0.9.0"
measureme = "12.0.1"
object = { version = "0.37.0", default-features = false, features = ["std", "read"] }
rustc-demangle = "0.1.21"
Expand Down
28 changes: 16 additions & 12 deletions compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,31 +528,35 @@ fn thin_lto(
}
}

fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
fn enable_autodiff_settings(cgcx: &CodegenContext<LlvmCodegenBackend>, ad: &[config::AutoDiff]) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe call this from LlvmCodegenBackend::init() instead? There you have access to the full Session.

// Initialize Enzyme if not already done (idempotent due to OnceLock)
// This ensures it works even if LlvmCodegenBackend::init() didn't run it
let mut enzyme = llvm::EnzymeWrapper::get_or_init(&cgcx.sysroot);

for val in ad {
// We intentionally don't use a wildcard, to not forget handling anything new.
match val {
config::AutoDiff::PrintPerf => {
llvm::set_print_perf(true);
enzyme.set_print_perf(true);
}
config::AutoDiff::PrintAA => {
llvm::set_print_activity(true);
enzyme.set_print_activity(true);
}
config::AutoDiff::PrintTA => {
llvm::set_print_type(true);
enzyme.set_print_type(true);
}
config::AutoDiff::PrintTAFn(fun) => {
llvm::set_print_type(true); // Enable general type printing
llvm::set_print_type_fun(&fun); // Set specific function to analyze
enzyme.set_print_type(true); // Enable general type printing
enzyme.set_print_type_fun(&fun); // Set specific function to analyze
}
config::AutoDiff::Inline => {
llvm::set_inline(true);
enzyme.set_inline(true);
}
config::AutoDiff::LooseTypes => {
llvm::set_loose_types(true);
enzyme.set_loose_types(true);
}
config::AutoDiff::PrintSteps => {
llvm::set_print(true);
enzyme.set_print(true);
}
// We handle this in the PassWrapper.cpp
config::AutoDiff::PrintPasses => {}
Expand All @@ -571,9 +575,9 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
}
}
// This helps with handling enums for now.
llvm::set_strict_aliasing(false);
enzyme.set_strict_aliasing(false);
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
llvm::set_rust_rules(true);
enzyme.set_rust_rules(true);
}

pub(crate) fn run_pass_manager(
Expand Down Expand Up @@ -609,7 +613,7 @@ pub(crate) fn run_pass_manager(
};

if enable_ad {
enable_autodiff_settings(&config.autodiff);
enable_autodiff_settings(&cgcx, &config.autodiff);
}

unsafe {
Expand Down
10 changes: 9 additions & 1 deletion compiler/rustc_codegen_llvm/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,14 @@ pub(crate) unsafe fn llvm_optimize(

let llvm_plugins = config.llvm_plugins.join(",");

let enzyme_fn = if consider_ad {
// Initialize Enzyme if not already done (idempotent due to OnceLock)
let wrapper = llvm::EnzymeWrapper::get_or_init(&cgcx.sysroot);
wrapper.registerEnzymeAndPassPipeline
} else {
std::ptr::null()
};

let result = unsafe {
llvm::LLVMRustOptimize(
module.module_llvm.llmod(),
Expand All @@ -745,7 +753,7 @@ pub(crate) unsafe fn llvm_optimize(
vectorize_loop,
config.no_builtins,
config.emit_lifetime_markers,
run_enzyme,
enzyme_fn,
print_before_enzyme,
print_after_enzyme,
print_passes,
Expand Down
8 changes: 7 additions & 1 deletion compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1099,7 +1099,13 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
// vs. copying a struct with mixed types requires different derivative handling.
// The TypeTree tells Enzyme exactly what memory layout to expect.
if let Some(tt) = tt {
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
crate::typetree::add_tt(
self.cx().llmod,
self.cx().llcx,
memcpy,
tt,
&self.cx().tcx.sess.opts.sysroot,
);
}
}

Expand Down
8 changes: 7 additions & 1 deletion compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
);

if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
crate::typetree::add_tt(
cx.llmod,
cx.llcx,
fn_to_diff,
fnc_tree,
&builder.cx().tcx.sess.opts.sysroot,
);
}

let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
Expand Down
9 changes: 8 additions & 1 deletion compiler/rustc_codegen_llvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId};
use rustc_middle::ty::TyCtxt;
use rustc_middle::util::Providers;
use rustc_session::Session;
use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest};
use rustc_session::config::{AutoDiff, OptLevel, OutputFilenames, PrintKind, PrintRequest};
use rustc_span::Symbol;
use rustc_target::spec::{RelocModel, TlsModel};

Expand Down Expand Up @@ -240,6 +240,13 @@ impl CodegenBackend for LlvmCodegenBackend {

fn init(&self, sess: &Session) {
llvm_util::init(sess); // Make sure llvm is inited

if sess.opts.unstable_opts.autodiff.contains(&AutoDiff::Enable) {
#[cfg(feature = "llvm_enzyme")]
{
drop(llvm::EnzymeWrapper::get_or_init(&sess.opts.sysroot));
}
}
}

fn provide(&self, providers: &mut Providers) {
Expand Down
Loading
Loading