Skip to content

Commit 4c51723

Browse files
Jammy2211claude
authored andcommitted
add CLAUDE.md documenting decorator system and JAX jit boundary rules
Documents AbstractMaker, the to_array/to_grid/to_vector_yx decorators, the .array property pattern for accessing raw grid data inside function bodies, and why autoarray types cannot be returned from jax.jit. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 5c4f7bb commit 4c51723

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed

CLAUDE.md

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# CLAUDE.md
2+
3+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4+
5+
## Commands
6+
7+
### Install
8+
```bash
9+
pip install -e ".[dev]"
10+
```
11+
12+
### Run Tests
13+
```bash
14+
# All tests
15+
python -m pytest test_autoarray/
16+
17+
# Single test file
18+
python -m pytest test_autoarray/structures/test_arrays.py
19+
20+
# With output
21+
python -m pytest test_autoarray/structures/test_arrays.py -s
22+
```
23+
24+
### Formatting
25+
```bash
26+
black autoarray/
27+
```
28+
29+
## Architecture
30+
31+
**PyAutoArray** is the low-level data structures and numerical utilities package for the PyAuto ecosystem. It provides:
32+
- **Grid and array structures** — uniform and irregular 2D grids, arrays, vector fields
33+
- **Masks** — 1D and 2D masks that define which pixels are active
34+
- **Datasets** — imaging and interferometer dataset containers
35+
- **Inversions / pixelizations** — sparse linear algebra for source reconstruction
36+
- **Decorators** — input/output homogenisation for grid-consuming functions
37+
38+
## Core Data Structures
39+
40+
All data structures inherit from `AbstractNDArray` (`abstract_ndarray.py`). Key subclasses:
41+
42+
| Class | Description |
43+
|---|---|
44+
| `Array2D` | Uniform 2D array tied to a `Mask2D` |
45+
| `ArrayIrregular` | Unmasked 1D collection of values |
46+
| `Grid2D` | Uniform (y,x) coordinate grid tied to a `Mask2D` |
47+
| `Grid2DIrregular` | Irregular (y,x) coordinate collection |
48+
| `VectorYX2D` | Uniform 2D vector field |
49+
| `VectorYX2DIrregular` | Irregular vector field |
50+
51+
`AbstractNDArray` provides arithmetic operators (`__add__`, `__sub__`, `__rsub__`, etc.), all decorated with `@to_new_array` and `@unwrap_array` so that operations between autoarray objects and raw scalars/arrays work naturally and return a new autoarray of the same type.
52+
53+
The `.array` property returns the raw underlying `numpy.ndarray` or `jax.Array`:
54+
```python
55+
arr = aa.ArrayIrregular(values=[1.0, 2.0])
56+
arr.array # raw numpy array
57+
arr._array # same, internal attribute
58+
```
59+
60+
The constructor unwraps nested autoarray objects automatically:
61+
```python
62+
# while isinstance(array, AbstractNDArray): array = array.array
63+
```
64+
65+
## Decorator System
66+
67+
`autoarray/structures/decorators/` contains three output-wrapping decorators used on all grid-consuming functions. They ensure that the **type of the output structure matches the type of the input grid**:
68+
69+
| Decorator | Grid2D input | Grid2DIrregular input |
70+
|---|---|---|
71+
| `@aa.grid_dec.to_array` | `Array2D` | `ArrayIrregular` |
72+
| `@aa.grid_dec.to_grid` | `Grid2D` | `Grid2DIrregular` |
73+
| `@aa.grid_dec.to_vector_yx` | `VectorYX2D` | `VectorYX2DIrregular` |
74+
75+
### How the decorators work
76+
77+
All three share `AbstractMaker` (`decorators/abstract.py`). The decorator:
78+
1. Wraps the function in a `wrapper(obj, grid, xp=np, *args, **kwargs)` signature
79+
2. Instantiates the relevant `*Maker` class with the function, object, grid, and `xp`
80+
3. `AbstractMaker.result` checks the grid type and calls the appropriate `via_grid_2d` / `via_grid_2d_irr` method to wrap the raw result
81+
82+
The function body receives the grid as-is and **must return a raw array** (not an autoarray wrapper). The decorator does the wrapping:
83+
84+
```python
85+
@aa.grid_dec.to_array
86+
def convergence_2d_from(self, grid, xp=np, **kwargs):
87+
# grid is Grid2D or Grid2DIrregular — access raw values via grid.array[:,0]
88+
y = grid.array[:, 0]
89+
x = grid.array[:, 1]
90+
return xp.sqrt(y**2 + x**2) # return raw array; decorator wraps it
91+
```
92+
93+
`AbstractMaker` also stores `use_jax = xp is not np` and exposes `_xp` (either `jnp` or `np`), but the wrapping step always runs regardless of `xp`. Autoarray types are **not registered as JAX pytrees**, so they cannot be directly returned from inside a `jax.jit` trace (see JAX section below).
94+
95+
### Accessing grid coordinates inside a decorated function
96+
97+
Inside a decorated function body, access the raw underlying array with `.array`:
98+
99+
```python
100+
# Correct — works for both numpy and jax backends
101+
y = grid.array[:, 0]
102+
x = grid.array[:, 1]
103+
104+
# Also correct for simple slicing (returns raw array via __getitem__)
105+
y = grid[:, 0]
106+
x = grid[:, 1]
107+
```
108+
109+
The `@transform` decorator (also in `decorators/`) shifts and rotates the input grid to the profile's reference frame before passing it to the function. It calls `obj.transformed_to_reference_frame_grid_from(grid, xp)` (decorated with `@to_grid`) and passes the result as the `grid` argument. After transformation the grid is still an autoarray object; `.array` still works.
110+
111+
### Decorator stacking order
112+
113+
Decorators are applied bottom-up (innermost first). The canonical order for mass/light profile methods is:
114+
115+
```python
116+
@aa.grid_dec.to_array # outermost: wraps output
117+
@aa.grid_dec.transform # innermost: transforms grid input
118+
def convergence_2d_from(self, grid, xp=np, **kwargs):
119+
...
120+
```
121+
122+
## JAX Support
123+
124+
The `xp` parameter pattern is the single point of control:
125+
- `xp=np` (default) — pure NumPy path
126+
- `xp=jnp` — JAX path; `jax` / `jax.numpy` are only imported locally
127+
128+
### Why autoarray types cannot be returned from `jax.jit`
129+
130+
`AbstractNDArray` subclasses (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. The `instance_flatten` / `instance_unflatten` class methods are defined on `AbstractNDArray` but are never passed to `jax.tree_util.register_pytree_node`. As a result:
131+
132+
- Constructing an autoarray wrapper **inside** a JIT trace is fine (Python-level code runs normally during tracing)
133+
- **Returning** an autoarray wrapper as the output of a `jax.jit`-compiled function **fails** with `TypeError: ... is not a valid JAX type`
134+
135+
### The `if xp is np:` guard pattern
136+
137+
Functions that are called directly inside `jax.jit` (i.e., as the outermost call in the lambda) must not return autoarray wrappers on the JAX path. The correct pattern is:
138+
139+
```python
140+
def convergence_2d_via_hessian_from(self, grid, xp=np):
141+
hessian_yy, hessian_xx = ...
142+
convergence = 0.5 * (hessian_yy + hessian_xx)
143+
144+
if xp is np:
145+
return aa.ArrayIrregular(values=convergence) # numpy: wrapped
146+
return convergence # jax: raw jax.Array
147+
```
148+
149+
This pattern is used in `autogalaxy/operate/lens_calc.py` for all `LensCalc` methods that are called inside `jax.jit`. It does **not** affect decorated helper functions (like `deflections_yx_2d_from`) because those are called as intermediate steps — their autoarray wrappers are consumed by downstream Python code, never returned as JIT outputs.
150+
151+
## Line Endings — Always Unix (LF)
152+
153+
All files **must use Unix line endings (LF, `\n`)**. Never write `\r\n` line endings.

0 commit comments

Comments
 (0)