Skip to content

Prevent ABI changes affect EnzymeAD #142544

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions compiler/rustc_monomorphize/src/partitioning/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,23 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec
continue;
}
}

let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty };

let layout = match tcx.layout_of(pci) {
Ok(layout) => layout.layout,
Err(_) => {
bug!("failed to compute layout for type {:?}", ty);
}
};

// If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
// Otherwise, the number of activities won't match the number of LLVM arguments and
// this will lead to errors when verifying the Enzyme call.
if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
new_activities.push(da[i].clone());
new_positions.push(i + 1);
}
}
// now add the extra activities coming from slices
// Reverse order to not invalidate the indices
Expand Down
331 changes: 331 additions & 0 deletions tests/codegen/autodiff/abi_handling.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,331 @@
//@ revisions: debug release

//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat
//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme

// This test checks that Rust types are lowered to LLVM-IR types in a way
// we expect and Enzyme can handle. We explicitly check release mode to
// ensure that LLVM's O3 pipeline doesn't rewrite function signatures
// into forms that Enzyme can't process correctly.

#![feature(autodiff)]

use std::autodiff::{autodiff_forward, autodiff_reverse};

#[derive(Copy, Clone)]
struct Input {
x: f32,
y: f32,
}

#[derive(Copy, Clone)]
struct Wrapper {
z: f32,
}

#[derive(Copy, Clone)]
struct NestedInput {
x: f32,
y: Wrapper,
}

fn square(x: f32) -> f32 {
x * x
}

// CHECK: ; abi_handling::f1
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f117h536ac8081c1e4101E
// debug-SAME: (ptr align 4 %x)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f117h536ac8081c1e4101E
// release-SAME: (float %x.0.val, float %x.4.val)
#[autodiff_forward(df1, Dual, Dual)]
fn f1(x: &[f32; 2]) -> f32 {
x[0] + x[1]
}

// CHECK: ; abi_handling::f2
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f217h33732e9f83c91bc9E
// debug-SAME: (ptr %f, float %x)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f217h33732e9f83c91bc9E
// release-SAME: (float noundef %x)
#[autodiff_reverse(df2, Const, Active, Active)]
fn f2(f: fn(f32) -> f32, x: f32) -> f32 {
f(x)
}

// CHECK: ; abi_handling::f3
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f317h9cd1fc602b0815a4E
// debug-SAME: (ptr align 4 %x, ptr align 4 %y)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f317h9cd1fc602b0815a4E
// release-SAME: (float %x.0.val)
#[autodiff_forward(df3, Dual, Dual, Dual)]
fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 {
*x * *y
}

// CHECK: ; abi_handling::f4
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE
// debug-SAME: (float %x.0, float %x.1)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE
// release-SAME: (float noundef %x.0, float noundef %x.1)
#[autodiff_forward(df4, Dual, Dual)]
fn f4(x: (f32, f32)) -> f32 {
x.0 * x.1
}

// CHECK: ; abi_handling::f5
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E
// debug-SAME: (float %i.0, float %i.1)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E
// release-SAME: (float noundef %i.0, float noundef %i.1)
#[autodiff_forward(df5, Dual, Dual)]
fn f5(i: Input) -> f32 {
i.x + i.y
}

// CHECK: ; abi_handling::f6
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f617h5784b207bbb2483eE
// debug-SAME: (float %i.0, float %i.1)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f617h5784b207bbb2483eE
// release-SAME: (float noundef %i.0, float noundef %i.1)
#[autodiff_forward(df6, Dual, Dual)]
fn f6(i: NestedInput) -> f32 {
i.x + i.y.z * i.y.z
}

// CHECK: ; abi_handling::f7
// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}}
// debug-NEXT: define internal float @_ZN12abi_handling2f717h44e3cff234e3b2d5E
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1)
// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f717h44e3cff234e3b2d5E
// release-SAME: (float %x.0.0.val, float %x.1.0.val)
#[autodiff_forward(df7, Dual, Dual)]
fn f7(x: (&f32, &f32)) -> f32 {
x.0 * x.1
}

// df1
// release: define internal fastcc { float, float }
// release-SAME: @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E
// release-SAME: (float %x.0.val, float %x.4.val)
// release-NEXT: start:
// release-NEXT: %_0 = fadd float %x.0.val, %x.4.val
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0
// release-NEXT: %1 = insertvalue { float, float } %0, float 1.000000e+00, 1
// release-NEXT: ret { float, float } %1
// release-NEXT: }

// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E
// debug-SAME: (ptr align 4 %x, ptr align 4 %"x'")
// debug-NEXT: start:
// debug-NEXT: %"'ipg" = getelementptr inbounds float, ptr %"x'", i64 0
// debug-NEXT: %0 = getelementptr inbounds nuw float, ptr %x, i64 0
// debug-NEXT: %"_2'ipl" = load float, ptr %"'ipg", align 4
// debug-NEXT: %_2 = load float, ptr %0, align 4
// debug-NEXT: %"'ipg2" = getelementptr inbounds float, ptr %"x'", i64 1
// debug-NEXT: %1 = getelementptr inbounds nuw float, ptr %x, i64 1
// debug-NEXT: %"_5'ipl" = load float, ptr %"'ipg2", align 4
// debug-NEXT: %_5 = load float, ptr %1, align 4
// debug-NEXT: %_0 = fadd float %_2, %_5
// debug-NEXT: %2 = fadd fast float %"_2'ipl", %"_5'ipl"
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
// debug-NEXT: ret { float, float } %4
// debug-NEXT: }

// df2
// release: define internal fastcc { float, float }
// release-SAME: @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E
// release-SAME: (float noundef %x)
// release-NEXT: invertstart:
// release-NEXT: %_0.i = fmul float %x, %x
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0.i, 0
// release-NEXT: %1 = insertvalue { float, float } %0, float 0.000000e+00, 1
// release-NEXT: ret { float, float } %1
// release-NEXT: }

// debug: define internal { float, float } @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E
// debug-SAME: (ptr %f, float %x, float %differeturn)
// debug-NEXT: start:
// debug-NEXT: %"x'de" = alloca float, align 4
// debug-NEXT: store float 0.000000e+00, ptr %"x'de", align 4
// debug-NEXT: %toreturn = alloca float, align 4
// debug-NEXT: %_0 = call float %f(float %x)
// debug-NEXT: store float %_0, ptr %toreturn, align 4
// debug-NEXT: br label %invertstart
// debug-EMPTY:
// debug-NEXT: invertstart: ; preds = %start
// debug-NEXT: %retreload = load float, ptr %toreturn, align 4
// debug-NEXT: %0 = load float, ptr %"x'de", align 4
// debug-NEXT: %1 = insertvalue { float, float } undef, float %retreload, 0
// debug-NEXT: %2 = insertvalue { float, float } %1, float %0, 1
// debug-NEXT: ret { float, float } %2
// debug-NEXT: }

// df3
// release: define internal fastcc { float, float }
// release-SAME: @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E
// release-SAME: (float %x.0.val)
// release-NEXT: start:
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.val, 0
// release-NEXT: %1 = insertvalue { float, float } %0, float 0x40099999A0000000, 1
// release-NEXT: ret { float, float } %1
// release-NEXT: }

// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E
// debug-SAME: (ptr align 4 %x, ptr align 4 %"x'", ptr align 4 %y, ptr align 4 %"y'")
// debug-NEXT: start:
// debug-NEXT: %"_3'ipl" = load float, ptr %"x'", align 4
// debug-NEXT: %_3 = load float, ptr %x, align 4
// debug-NEXT: %"_4'ipl" = load float, ptr %"y'", align 4
// debug-NEXT: %_4 = load float, ptr %y, align 4
// debug-NEXT: %_0 = fmul float %_3, %_4
// debug-NEXT: %0 = fmul fast float %"_3'ipl", %_4
// debug-NEXT: %1 = fmul fast float %"_4'ipl", %_3
// debug-NEXT: %2 = fadd fast float %0, %1
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
// debug-NEXT: ret { float, float } %4
// debug-NEXT: }

// df4
// release: define internal fastcc { float, float }
// release-SAME: @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE
// release-SAME: (float noundef %x.0, float %"x.0'")
// release-NEXT: start:
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0, 0
// release-NEXT: %1 = insertvalue { float, float } %0, float %"x.0'", 1
// release-NEXT: ret { float, float } %1
// release-NEXT: }

// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE
// debug-SAME: (float %x.0, float %"x.0'", float %x.1, float %"x.1'")
// debug-NEXT: start:
// debug-NEXT: %_0 = fmul float %x.0, %x.1
// debug-NEXT: %0 = fmul fast float %"x.0'", %x.1
// debug-NEXT: %1 = fmul fast float %"x.1'", %x.0
// debug-NEXT: %2 = fadd fast float %0, %1
// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
// debug-NEXT: ret { float, float } %4
// debug-NEXT: }

// df5
// release: define internal fastcc { float, float }
// release-SAME: @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E
// release-SAME: (float noundef %i.0, float %"i.0'")
// release-NEXT: start:
// release-NEXT: %_0 = fadd float %i.0, 1.000000e+00
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0
// release-NEXT: %1 = insertvalue { float, float } %0, float %"i.0'", 1
// release-NEXT: ret { float, float } %1
// release-NEXT: }

// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E
// debug-SAME: (float %i.0, float %"i.0'", float %i.1, float %"i.1'")
// debug-NEXT: start:
// debug-NEXT: %_0 = fadd float %i.0, %i.1
// debug-NEXT: %0 = fadd fast float %"i.0'", %"i.1'"
// debug-NEXT: %1 = insertvalue { float, float } undef, float %_0, 0
// debug-NEXT: %2 = insertvalue { float, float } %1, float %0, 1
// debug-NEXT: ret { float, float } %2
// debug-NEXT: }

// df6
// release: define internal fastcc { float, float }
// release-SAME: @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE
// release-SAME: (float noundef %i.0, float %"i.0'", float noundef %i.1, float %"i.1'")
// release-NEXT: start:
// release-NEXT: %_3 = fmul float %i.1, %i.1
// release-NEXT: %0 = fadd fast float %"i.1'", %"i.1'"
// release-NEXT: %1 = fmul fast float %0, %i.1
// release-NEXT: %_0 = fadd float %i.0, %_3
// release-NEXT: %2 = fadd fast float %"i.0'", %1
// release-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0
// release-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
// release-NEXT: ret { float, float } %4
// release-NEXT: }

// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE
// debug-SAME: (float %i.0, float %"i.0'", float %i.1, float %"i.1'")
// debug-NEXT: start:
// debug-NEXT: %_3 = fmul float %i.1, %i.1
// debug-NEXT: %0 = fmul fast float %"i.1'", %i.1
// debug-NEXT: %1 = fmul fast float %"i.1'", %i.1
// debug-NEXT: %2 = fadd fast float %0, %1
// debug-NEXT: %_0 = fadd float %i.0, %_3
// debug-NEXT: %3 = fadd fast float %"i.0'", %2
// debug-NEXT: %4 = insertvalue { float, float } undef, float %_0, 0
// debug-NEXT: %5 = insertvalue { float, float } %4, float %3, 1
// debug-NEXT: ret { float, float } %5
// debug-NEXT: }

// df7
// release: define internal fastcc { float, float }
// release-SAME: @fwddiffe_ZN12abi_handling2f717h44e3cff234e3b2d5E
// release-SAME: (float %x.0.0.val, float %"x.0'.0.val")
// release-NEXT: start:
// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.0.val, 0
// release-NEXT: %1 = insertvalue { float, float } %0, float %"x.0'.0.val", 1
// release-NEXT: ret { float, float } %1
// release-NEXT: }

// debug: define internal { float, float }
// debug-SAME: @fwddiffe_ZN12abi_handling2f717h44e3cff234e3b2d5E
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %"x.0'", ptr align 4 %x.1, ptr align 4 %"x.1'")
// debug-NEXT: start:
// debug-NEXT: %0 = call fast { float, float } @"fwddiffe_ZN49_{{.*}}"
// debug-NEXT: %1 = extractvalue { float, float } %0, 0
// debug-NEXT: %2 = extractvalue { float, float } %0, 1
// debug-NEXT: %3 = insertvalue { float, float } undef, float %1, 0
// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1
// debug-NEXT: ret { float, float } %4
// debug-NEXT: }

fn main() {
let x = std::hint::black_box(2.0);
let y = std::hint::black_box(3.0);
let z = std::hint::black_box(4.0);
static Y: f32 = std::hint::black_box(3.2);

let in_f1 = [x, y];
dbg!(f1(&in_f1));
let res_f1 = df1(&in_f1, &[1.0, 0.0]);
dbg!(res_f1);

dbg!(f2(square, x));
let res_f2 = df2(square, x, 1.0);
dbg!(res_f2);

dbg!(f3(&x, &Y));
let res_f3 = df3(&x, &Y, &1.0, &0.0);
dbg!(res_f3);

let in_f4 = (x, y);
dbg!(f4(in_f4));
let res_f4 = df4(in_f4, (1.0, 0.0));
dbg!(res_f4);

let in_f5 = Input { x, y };
dbg!(f5(in_f5));
let res_f5 = df5(in_f5, Input { x: 1.0, y: 0.0 });
dbg!(res_f5);

let in_f6 = NestedInput { x, y: Wrapper { z: y } };
dbg!(f6(in_f6));
let res_f6 = df6(in_f6, NestedInput { x, y: Wrapper { z } });
dbg!(res_f6);

let in_f7 = (&x, &y);
dbg!(f7(in_f7));
let res_f7 = df7(in_f7, (&1.0, &0.0));
dbg!(res_f7);
}
Loading