Skip to content

Commit c3dd859

Browse files
committed
autodiff: add TypeTree support for arrays
1 parent d447882 commit c3dd859

File tree

4 files changed

+69
-1
lines changed

4 files changed

+69
-1
lines changed

compiler/rustc_middle/src/ty/mod.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2292,6 +2292,46 @@ pub fn typetree_from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> TypeTree {
22922292
}]);
22932293
}
22942294

2295-
// FIXME(KMJ-007): Handle arrays, slices, structs, and other complex types
2295+
if ty.is_array() {
2296+
if let ty::Array(element_ty, len_const) = ty.kind() {
2297+
let len = len_const.try_to_target_usize(tcx).unwrap_or(0);
2298+
if len == 0 {
2299+
return TypeTree::new();
2300+
}
2301+
2302+
let element_tree = typetree_from_ty(tcx, *element_ty);
2303+
2304+
let element_layout = tcx
2305+
.layout_of(ty::TypingEnv::fully_monomorphized().as_query_input(*element_ty))
2306+
.ok()
2307+
.map(|layout| layout.size.bytes_usize())
2308+
.unwrap_or(0);
2309+
2310+
if element_layout == 0 {
2311+
return TypeTree::new();
2312+
}
2313+
2314+
let mut types = Vec::new();
2315+
for i in 0..len {
2316+
let base_offset = (i as usize * element_layout) as isize;
2317+
2318+
for elem_type in &element_tree.0 {
2319+
types.push(Type {
2320+
offset: if elem_type.offset == -1 {
2321+
base_offset
2322+
} else {
2323+
base_offset + elem_type.offset
2324+
},
2325+
size: elem_type.size,
2326+
kind: elem_type.kind,
2327+
child: elem_type.child.clone(),
2328+
});
2329+
}
2330+
}
2331+
2332+
return TypeTree(types);
2333+
}
2334+
}
2335+
22962336
TypeTree::new()
22972337
}
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
; Check that array TypeTree metadata is correctly generated
2+
; Should show Float@double at each array element offset (0, 8, 16, 24, 32 bytes)
3+
4+
CHECK: define{{.*}}"enzyme_type"="{[]:Float@double}"{{.*}}@test_array{{.*}}"enzyme_type"="{[]:Pointer}"
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
//@ needs-enzyme
2+
//@ ignore-cross-compile
3+
4+
use run_make_support::{llvm_filecheck, rfs, rustc};
5+
6+
fn main() {
7+
rustc().input("test.rs").arg("-Zautodiff=Enable").emit("llvm-ir").run();
8+
llvm_filecheck().patterns("array.check").stdin_buf(rfs::read("test.ll")).run();
9+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
#![feature(autodiff)]
2+
3+
use std::autodiff::autodiff_reverse;
4+
5+
#[autodiff_reverse(d_test, Duplicated, Active)]
6+
#[no_mangle]
7+
fn test_array(arr: &[f64; 5]) -> f64 {
8+
arr[0] + arr[1] + arr[2] + arr[3] + arr[4]
9+
}
10+
11+
fn main() {
12+
let arr = [1.0, 2.0, 3.0, 4.0, 5.0];
13+
let mut d_arr = [0.0; 5];
14+
let _result = d_test(&arr, &mut d_arr, 1.0);
15+
}

0 commit comments

Comments
 (0)