|
| 1 | +# CLAUDE.md |
| 2 | + |
| 3 | +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. |
| 4 | + |
| 5 | +## Commands |
| 6 | + |
| 7 | +### Install |
| 8 | +```bash |
| 9 | +pip install -e ".[dev]" |
| 10 | +``` |
| 11 | + |
| 12 | +### Run Tests |
| 13 | +```bash |
| 14 | +# All tests |
| 15 | +python -m pytest test_autoarray/ |
| 16 | + |
| 17 | +# Single test file |
| 18 | +python -m pytest test_autoarray/structures/test_arrays.py |
| 19 | + |
| 20 | +# With output |
| 21 | +python -m pytest test_autoarray/structures/test_arrays.py -s |
| 22 | +``` |
| 23 | + |
| 24 | +### Formatting |
| 25 | +```bash |
| 26 | +black autoarray/ |
| 27 | +``` |
| 28 | + |
| 29 | +## Architecture |
| 30 | + |
| 31 | +**PyAutoArray** is the low-level data structures and numerical utilities package for the PyAuto ecosystem. It provides: |
| 32 | +- **Grid and array structures** — uniform and irregular 2D grids, arrays, vector fields |
| 33 | +- **Masks** — 1D and 2D masks that define which pixels are active |
| 34 | +- **Datasets** — imaging and interferometer dataset containers |
| 35 | +- **Inversions / pixelizations** — sparse linear algebra for source reconstruction |
| 36 | +- **Decorators** — input/output homogenisation for grid-consuming functions |
| 37 | + |
| 38 | +## Core Data Structures |
| 39 | + |
| 40 | +All data structures inherit from `AbstractNDArray` (`abstract_ndarray.py`). Key subclasses: |
| 41 | + |
| 42 | +| Class | Description | |
| 43 | +|---|---| |
| 44 | +| `Array2D` | Uniform 2D array tied to a `Mask2D` | |
| 45 | +| `ArrayIrregular` | Unmasked 1D collection of values | |
| 46 | +| `Grid2D` | Uniform (y,x) coordinate grid tied to a `Mask2D` | |
| 47 | +| `Grid2DIrregular` | Irregular (y,x) coordinate collection | |
| 48 | +| `VectorYX2D` | Uniform 2D vector field | |
| 49 | +| `VectorYX2DIrregular` | Irregular vector field | |
| 50 | + |
| 51 | +`AbstractNDArray` provides arithmetic operators (`__add__`, `__sub__`, `__rsub__`, etc.), all decorated with `@to_new_array` and `@unwrap_array` so that operations between autoarray objects and raw scalars/arrays work naturally and return a new autoarray of the same type. |
| 52 | + |
| 53 | +The `.array` property returns the raw underlying `numpy.ndarray` or `jax.Array`: |
| 54 | +```python |
| 55 | +arr = aa.ArrayIrregular(values=[1.0, 2.0]) |
| 56 | +arr.array # raw numpy array |
| 57 | +arr._array # same, internal attribute |
| 58 | +``` |
| 59 | + |
| 60 | +The constructor unwraps nested autoarray objects automatically: |
| 61 | +```python |
| 62 | +# while isinstance(array, AbstractNDArray): array = array.array |
| 63 | +``` |
| 64 | + |
| 65 | +## Decorator System |
| 66 | + |
| 67 | +`autoarray/structures/decorators/` contains three output-wrapping decorators used on all grid-consuming functions. They ensure that the **type of the output structure matches the type of the input grid**: |
| 68 | + |
| 69 | +| Decorator | Grid2D input | Grid2DIrregular input | |
| 70 | +|---|---|---| |
| 71 | +| `@aa.grid_dec.to_array` | `Array2D` | `ArrayIrregular` | |
| 72 | +| `@aa.grid_dec.to_grid` | `Grid2D` | `Grid2DIrregular` | |
| 73 | +| `@aa.grid_dec.to_vector_yx` | `VectorYX2D` | `VectorYX2DIrregular` | |
| 74 | + |
| 75 | +### How the decorators work |
| 76 | + |
| 77 | +All three share `AbstractMaker` (`decorators/abstract.py`). The decorator: |
| 78 | +1. Wraps the function in a `wrapper(obj, grid, xp=np, *args, **kwargs)` signature |
| 79 | +2. Instantiates the relevant `*Maker` class with the function, object, grid, and `xp` |
| 80 | +3. `AbstractMaker.result` checks the grid type and calls the appropriate `via_grid_2d` / `via_grid_2d_irr` method to wrap the raw result |
| 81 | + |
| 82 | +The function body receives the grid as-is and **must return a raw array** (not an autoarray wrapper). The decorator does the wrapping: |
| 83 | + |
| 84 | +```python |
| 85 | +@aa.grid_dec.to_array |
| 86 | +def convergence_2d_from(self, grid, xp=np, **kwargs): |
| 87 | + # grid is Grid2D or Grid2DIrregular — access raw values via grid.array[:,0] |
| 88 | + y = grid.array[:, 0] |
| 89 | + x = grid.array[:, 1] |
| 90 | + return xp.sqrt(y**2 + x**2) # return raw array; decorator wraps it |
| 91 | +``` |
| 92 | + |
| 93 | +`AbstractMaker` also stores `use_jax = xp is not np` and exposes `_xp` (either `jnp` or `np`), but the wrapping step always runs regardless of `xp`. Autoarray types are **not registered as JAX pytrees**, so they cannot be directly returned from inside a `jax.jit` trace (see JAX section below). |
| 94 | + |
| 95 | +### Accessing grid coordinates inside a decorated function |
| 96 | + |
| 97 | +Inside a decorated function body, access the raw underlying array with `.array`: |
| 98 | + |
| 99 | +```python |
| 100 | +# Correct — works for both numpy and jax backends |
| 101 | +y = grid.array[:, 0] |
| 102 | +x = grid.array[:, 1] |
| 103 | + |
| 104 | +# Also correct for simple slicing (returns raw array via __getitem__) |
| 105 | +y = grid[:, 0] |
| 106 | +x = grid[:, 1] |
| 107 | +``` |
| 108 | + |
| 109 | +The `@transform` decorator (also in `decorators/`) shifts and rotates the input grid to the profile's reference frame before passing it to the function. It calls `obj.transformed_to_reference_frame_grid_from(grid, xp)` (decorated with `@to_grid`) and passes the result as the `grid` argument. After transformation the grid is still an autoarray object; `.array` still works. |
| 110 | + |
| 111 | +### Decorator stacking order |
| 112 | + |
| 113 | +Decorators are applied bottom-up (innermost first). The canonical order for mass/light profile methods is: |
| 114 | + |
| 115 | +```python |
| 116 | +@aa.grid_dec.to_array # outermost: wraps output |
| 117 | +@aa.grid_dec.transform # innermost: transforms grid input |
| 118 | +def convergence_2d_from(self, grid, xp=np, **kwargs): |
| 119 | + ... |
| 120 | +``` |
| 121 | + |
| 122 | +## JAX Support |
| 123 | + |
| 124 | +The `xp` parameter pattern is the single point of control: |
| 125 | +- `xp=np` (default) — pure NumPy path |
| 126 | +- `xp=jnp` — JAX path; `jax` / `jax.numpy` are only imported locally |
| 127 | + |
| 128 | +### Why autoarray types cannot be returned from `jax.jit` |
| 129 | + |
| 130 | +`AbstractNDArray` subclasses (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. The `instance_flatten` / `instance_unflatten` class methods are defined on `AbstractNDArray` but are never passed to `jax.tree_util.register_pytree_node`. As a result: |
| 131 | + |
| 132 | +- Constructing an autoarray wrapper **inside** a JIT trace is fine (Python-level code runs normally during tracing) |
| 133 | +- **Returning** an autoarray wrapper as the output of a `jax.jit`-compiled function **fails** with `TypeError: ... is not a valid JAX type` |
| 134 | + |
| 135 | +### The `if xp is np:` guard pattern |
| 136 | + |
| 137 | +Functions that are called directly inside `jax.jit` (i.e., as the outermost call in the lambda) must not return autoarray wrappers on the JAX path. The correct pattern is: |
| 138 | + |
| 139 | +```python |
| 140 | +def convergence_2d_via_hessian_from(self, grid, xp=np): |
| 141 | + hessian_yy, hessian_xx = ... |
| 142 | + convergence = 0.5 * (hessian_yy + hessian_xx) |
| 143 | + |
| 144 | + if xp is np: |
| 145 | + return aa.ArrayIrregular(values=convergence) # numpy: wrapped |
| 146 | + return convergence # jax: raw jax.Array |
| 147 | +``` |
| 148 | + |
| 149 | +This pattern is used in `autogalaxy/operate/lens_calc.py` for all `LensCalc` methods that are called inside `jax.jit`. It does **not** affect decorated helper functions (like `deflections_yx_2d_from`) because those are called as intermediate steps — their autoarray wrappers are consumed by downstream Python code, never returned as JIT outputs. |
| 150 | + |
| 151 | +## Line Endings — Always Unix (LF) |
| 152 | + |
| 153 | +All files **must use Unix line endings (LF, `\n`)**. Never write `\r\n` line endings. |
0 commit comments