| 1 | +# TypeTrees for Autodiff |
| 2 | + |
| 3 | +## What are TypeTrees? |
| 4 | +TypeTrees are memory layout descriptors for Enzyme. They tell Enzyme exactly how types are laid out in memory so that it can compute derivatives efficiently. |
| 5 | + |
| 6 | +## Structure |
| 7 | +```rust |
| 8 | +struct TypeTree(Vec<Type>); |
| 9 | + |
| 10 | +struct Type { |
| 11 | + offset: isize, // byte offset (-1 = everywhere) |
| 12 | + size: usize, // size in bytes |
| 13 | + kind: Kind, // Float, Integer, Pointer, etc. |
| 14 | + child: TypeTree, // nested structure |
| 15 | +} |
| 16 | +``` |
| 17 | + |
| 18 | +## Example: `fn compute(x: &f32, data: &[f32]) -> f32` |
| 19 | + |
| 20 | +**Input 0: `x: &f32`** |
| 21 | +```rust |
| 22 | +TypeTree(vec![Type { |
| 23 | + offset: -1, size: 8, kind: Pointer, |
| 24 | + child: TypeTree(vec![Type { |
| 25 | + offset: 0, size: 4, kind: Float, // Single value: use offset 0 |
| 26 | + child: TypeTree::new() |
| 27 | + }]) |
| 28 | +}]) |
| 29 | +``` |
| 30 | + |
| 31 | +**Input 1: `data: &[f32]`** |
| 32 | +```rust |
| 33 | +TypeTree(vec![Type { |
| 34 | + offset: -1, size: 8, kind: Pointer, |
| 35 | + child: TypeTree(vec![Type { |
| 36 | + offset: -1, size: 4, kind: Float, // -1 = all elements |
| 37 | + child: TypeTree::new() |
| 38 | + }]) |
| 39 | +}]) |
| 40 | +``` |
| 41 | + |
| 42 | +**Output: `f32`** |
| 43 | +```rust |
| 44 | +TypeTree(vec![Type { |
| 45 | + offset: 0, size: 4, kind: Float, // Single scalar: use offset 0 |
| 46 | + child: TypeTree::new() |
| 47 | +}]) |
| 48 | +``` |
| 49 | + |
| 50 | +## Why Needed? |
| 51 | +- Enzyme can't deduce complex type layouts from LLVM IR |
| 52 | +- Avoids the slow analysis of memory access patterns that Enzyme otherwise has to perform |
| 53 | +- Enables correct derivative computation for nested structures |
| 54 | +- Tells Enzyme which bytes are differentiable vs metadata |
| 55 | + |
| 56 | +## What Enzyme Does With This Information |
| 57 | + |
| 58 | +Without TypeTrees: |
| 59 | +```llvm |
| 60 | +; Enzyme sees generic LLVM IR: |
| 61 | +define float @distance(ptr %p1, ptr %p2) { |
| 62 | +; Has to guess what these pointers point to |
| 63 | +; Slow analysis of all memory operations |
| 64 | +; May miss optimization opportunities |
| 65 | +} |
| 66 | +``` |
| 67 | + |
| 68 | +With TypeTrees: |
| 69 | +```llvm |
| 70 | +define "enzyme_type"="{[-1]:Float@float}" float @distance( |
| 71 | + ptr "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}" %p1, |
| 72 | + ptr "enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float}" %p2 |
| 73 | +) { |
| 74 | +; Enzyme knows exact type layout |
| 75 | +; Can generate efficient derivative code directly |
| 76 | +} |
| 77 | +``` |
| 78 | + |
| 79 | +# TypeTrees - Offset and -1 Explained |
| 80 | + |
| 81 | +## Type Structure |
| 82 | + |
| 83 | +```rust |
| 84 | +struct Type { |
| 85 | + offset: isize, // WHERE this type starts |
| 86 | + size: usize, // HOW BIG this type is |
| 87 | + kind: Kind, // WHAT KIND of data (Float, Integer, Pointer) |
| 88 | + child: TypeTree, // WHAT'S INSIDE (for pointers/containers) |
| 89 | +} |
| 90 | +``` |
| 91 | + |
| 92 | +## Offset Values |
| 93 | + |
| 94 | +### Regular Offset (0, 4, 8, etc.) |
| 95 | +**Specific byte position within a structure** |
| 96 | + |
| 97 | +```rust |
| 98 | +struct Point { |
| 99 | + x: f32, // offset 0, size 4 |
| 100 | + y: f32, // offset 4, size 4 |
| 101 | + id: i32, // offset 8, size 4 |
| 102 | +} |
| 103 | +``` |
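The offsets in the comments above can be checked at run time. The sketch below is illustrative only: it adds `#[repr(C)]` (not present in the original struct) so that fields are guaranteed to sit in declaration order, and `offset_of_field` and `point_layout` are hypothetical helper names.

```rust
// Assumption: `#[repr(C)]` is added here to pin fields to declaration order.
// Without it, Rust is free to reorder struct fields.
#[repr(C)]
struct Point {
    x: f32,
    y: f32,
    id: i32,
}

/// Byte offset of `field` measured from the start of `base`.
/// Hypothetical helper; `field` must be a reference into `base`.
fn offset_of_field<T, F>(base: &T, field: &F) -> usize {
    field as *const F as usize - base as *const T as usize
}

/// Returns (offset, size) for `x`, `y`, and `id`, in that order.
fn point_layout() -> [(usize, usize); 3] {
    let p = Point { x: 1.0, y: 2.0, id: 7 };
    [
        (offset_of_field(&p, &p.x), std::mem::size_of::<f32>()),
        (offset_of_field(&p, &p.y), std::mem::size_of::<f32>()),
        (offset_of_field(&p, &p.id), std::mem::size_of::<i32>()),
    ]
}

fn main() {
    for (offset, size) in point_layout() {
        println!("offset {}, size {}", offset, size);
    }
}
```

Each `(offset, size)` pair corresponds to one `Type` entry in the tree for `Point`.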
| 104 | + |
| 105 | +TypeTree for the pointee of `&Point` (internal representation, shown without the outer `Pointer` layer and with empty `child` fields omitted): |
| 106 | +```rust |
| 107 | +TypeTree(vec![ |
| 108 | + Type { offset: 0, size: 4, kind: Float }, // x at byte 0 |
| 109 | + Type { offset: 4, size: 4, kind: Float }, // y at byte 4 |
| 110 | + Type { offset: 8, size: 4, kind: Integer } // id at byte 8 |
| 111 | +]) |
| 112 | +``` |
| 113 | + |
| 114 | +Generates LLVM: |
| 115 | +```llvm |
| 116 | +"enzyme_type"="{[-1]:Pointer, [-1,0]:Float@float, [-1,4]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer}" |
| 117 | +``` |
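To make the mapping from tree to attribute string concrete, here is a minimal sketch of a renderer that reproduces the string above. It is not the compiler's implementation. The `Kind` enum, `collect`, `render`, `leaf`, and `point_ref_tree` are hypothetical names. Two rules are inferred only from the examples in this document: an integer at a concrete offset is listed once per byte, and every float is printed as `Float@float` (that is, `f32`).

```rust
// Hypothetical, simplified mirror of the TypeTree structure described in this document.
#[derive(Clone, Copy, PartialEq)]
enum Kind {
    Float,
    Integer,
    Pointer,
}

struct Type {
    offset: isize,
    size: usize,
    kind: Kind,
    child: TypeTree,
}

struct TypeTree(Vec<Type>);

/// Flattens `tree` into "[path]:Kind" entries.
/// `prefix` holds the offsets of the enclosing layers.
fn collect(tree: &TypeTree, prefix: &[isize], out: &mut Vec<String>) {
    for ty in &tree.0 {
        // Assumption inferred from the Point example: an integer at a concrete
        // offset is listed once per byte; everything else is listed once.
        let offsets: Vec<isize> = if ty.kind == Kind::Integer && ty.offset >= 0 {
            (0..ty.size as isize).map(|b| ty.offset + b).collect()
        } else {
            vec![ty.offset]
        };
        let name = match ty.kind {
            Kind::Float => "Float@float", // assumes every float is an f32
            Kind::Integer => "Integer",
            Kind::Pointer => "Pointer",
        };
        for off in offsets {
            let mut path = prefix.to_vec();
            path.push(off);
            let idx: Vec<String> = path.iter().map(|o| o.to_string()).collect();
            out.push(format!("[{}]:{}", idx.join(","), name));
        }
        // Entries of the child tree are nested under this type's offset.
        let mut path = prefix.to_vec();
        path.push(ty.offset);
        collect(&ty.child, &path, out);
    }
}

/// Renders a tree in the `{[..]:Kind, ...}` form used by the attribute.
fn render(tree: &TypeTree) -> String {
    let mut out = Vec::new();
    collect(tree, &[], &mut out);
    format!("{{{}}}", out.join(", "))
}

/// Builds a `Type` with an empty child tree.
fn leaf(offset: isize, size: usize, kind: Kind) -> Type {
    Type { offset, size, kind, child: TypeTree(Vec::new()) }
}

/// `&Point`: a pointer layer whose child lists the three fields.
fn point_ref_tree() -> TypeTree {
    TypeTree(vec![Type {
        offset: -1,
        size: 8,
        kind: Kind::Pointer,
        child: TypeTree(vec![
            leaf(0, 4, Kind::Float),
            leaf(4, 4, Kind::Float),
            leaf(8, 4, Kind::Integer),
        ]),
    }])
}

fn main() {
    println!("\"enzyme_type\"=\"{}\"", render(&point_ref_tree()));
}
```

In this sketch the pointer layer is written out explicitly, which is where the leading `[-1]:Pointer` entry and the `-1` prefix on every field path come from.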
| 118 | + |
| 119 | +### Offset -1 (Special: "Everywhere") |
| 120 | +**Means "this pattern repeats for ALL elements"** |
| 121 | + |
| 122 | +#### Example 1: Direct Array `[f32; 100]` (no pointer indirection) |
| 123 | +```rust |
| 124 | +TypeTree(vec![Type { |
| 125 | + offset: -1, // ALL positions |
| 126 | + size: 4, // each f32 is 4 bytes |
| 127 | + kind: Float, // every element is float |
| 128 | +}]) |
| 129 | +``` |
| 130 | + |
| 131 | +Generates LLVM: `"enzyme_type"="{[-1]:Float@float}"` |
| 132 | + |
| 133 | +#### Example 1b: Array Reference `&[f32; 100]` (with pointer indirection) |
| 134 | +```rust |
| 135 | +TypeTree(vec![Type { |
| 136 | + offset: -1, size: 8, kind: Pointer, |
| 137 | + child: TypeTree(vec![Type { |
| 138 | + offset: -1, // ALL array elements |
| 139 | + size: 4, // each f32 is 4 bytes |
| 140 | + kind: Float, // every element is float |
| 141 | + }]) |
| 142 | +}]) |
| 143 | +``` |
| 144 | + |
| 145 | +Generates LLVM: `"enzyme_type"="{[-1]:Pointer, [-1,-1]:Float@float}"` |
| 146 | + |
| 147 | +A single `-1` entry replaces 100 separate `Type` entries at offsets `0, 4, 8, 12, ..., 396`. |
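For a sense of scale, the explicit offsets that the `-1` entry stands in for can be enumerated with a few lines of Rust. This is only an arithmetic illustration; `explicit_offsets` is a hypothetical helper and not part of any compiler API.

```rust
/// Byte offset of every element in an array of `len` elements,
/// each `elem_size` bytes wide.
fn explicit_offsets(len: usize, elem_size: usize) -> Vec<usize> {
    (0..len).map(|i| i * elem_size).collect()
}

fn main() {
    // [f32; 100]: 100 elements of 4 bytes each.
    let offsets = explicit_offsets(100, 4);
    println!(
        "{} entries, first {:?}, last {:?}",
        offsets.len(),
        offsets.first(),
        offsets.last()
    );
}
```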
| 148 | + |
| 149 | +#### Example 2: Slice `&[i32]` |
| 150 | +```rust |
| 151 | +// Pointer to slice data |
| 152 | +TypeTree(vec![Type { |
| 153 | + offset: -1, size: 8, kind: Pointer, |
| 154 | + child: TypeTree(vec![Type { |
| 155 | + offset: -1, // ALL slice elements |
| 156 | + size: 4, // each i32 is 4 bytes |
| 157 | + kind: Integer |
| 158 | + }]) |
| 159 | +}]) |
| 160 | +``` |
| 161 | + |
| 162 | +Generates LLVM: `"enzyme_type"="{[-1]:Pointer, [-1,-1]:Integer}"` |
| 163 | + |
| 164 | +#### Example 3: Mixed Structure |
| 165 | +```rust |
| 166 | +struct Container { |
| 167 | + header: i64, // offset 0 |
| 168 | + data: [f32; 1000], // offset 8, but elements use -1 |
| 169 | +} |
| 170 | +``` |
| 171 | + |
| 172 | +```rust |
| 173 | +TypeTree(vec![ |
| 174 | + Type { offset: 0, size: 8, kind: Integer }, // header |
| 175 | + Type { offset: 8, size: 4000, kind: Pointer, |
| 176 | + child: TypeTree(vec![Type { |
| 177 | + offset: -1, size: 4, kind: Float // ALL array elements |
| 178 | + }]) |
| 179 | + } |
| 180 | +]) |
| 181 | +``` |
| 182 | + |
| 183 | +## Key Distinction: Single Values vs Arrays |
| 184 | + |
| 185 | +**Single Values** use offset `0` for precision: |
| 186 | +- `&f32` has exactly one f32 value at offset 0 |
| 187 | +- More precise than using -1 ("everywhere") |
| 188 | +- Generates: `{[-1]:Pointer, [-1,0]:Float@float}` |
| 189 | + |
| 190 | +**Arrays** use offset `-1` for efficiency: |
| 191 | +- `&[f32; 100]` has the same pattern repeated 100 times |
| 192 | +- Using -1 avoids listing 100 separate offsets |
| 193 | +- Generates: `{[-1]:Pointer, [-1,-1]:Float@float}` |