Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,32 @@ Each dataset type has an `Analysis*` class that implements `log_likelihood_funct

These inherit from `AnalysisDataset` → `Analysis` (in `analysis/analysis/`), which inherits `af.Analysis`. The `log_likelihood_function` builds a `Fit*` object from the `af.ModelInstance` and returns its `figure_of_merit`.

### Decorator System (from autoarray)

Profile methods that consume a grid and return an array, grid, or vector use decorators from `autoarray.structures.decorators`. These ensure the **output type matches the input grid type**:

| Decorator | `Grid2D` input | `Grid2DIrregular` input |
|---|---|---|
| `@aa.grid_dec.to_array` | `Array2D` | `ArrayIrregular` |
| `@aa.grid_dec.to_grid` | `Grid2D` | `Grid2DIrregular` |
| `@aa.grid_dec.to_vector_yx` | `VectorYX2D` | `VectorYX2DIrregular` |

The `@aa.grid_dec.transform` decorator (always stacked below the output decorator) shifts and rotates the grid to the profile's reference frame before passing it to the function body.

The canonical stacking order is:
```python
@aa.grid_dec.to_array # outermost: wraps output
@aa.grid_dec.transform # innermost: transforms grid
def convergence_2d_from(self, grid, xp=np, **kwargs):
y = grid.array[:, 0] # use .array to get raw numpy/jax array
x = grid.array[:, 1]
return ... # return raw array; decorator wraps it
```

**Key rule**: the function body must return a **raw array** (not an autoarray). The decorator handles wrapping. Access grid coordinates via `grid.array[:, 0]` / `grid.array[:, 1]` (not `grid[:, 0]`), because after `@transform` the grid is still an autoarray object and `.array` is the safe way to extract the underlying data for both numpy and jax backends.

See PyAutoArray's `CLAUDE.md` for full details on the decorator internals.

### JAX Support

The codebase is designed so that **NumPy is the default everywhere and JAX is opt-in**. JAX is never imported at module level — it is only imported locally inside functions when explicitly requested.
Expand All @@ -100,6 +126,28 @@ When adding a new function that should support JAX:
4. Add a JAX implementation in the guarded branch (e.g. `jax.jacfwd`, `jnp.vectorize`)
5. Verify correctness by comparing both paths in `autogalaxy_workspace_test/scripts/`

### JAX and autoarray wrappers at the `jax.jit` boundary

Autoarray types (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. This means:

- Constructing them **inside** a JIT trace is fine (Python code runs normally during tracing)
- **Returning** them as the output of a `jax.jit`-compiled function **fails** with `TypeError: ... is not a valid JAX type`

Functions decorated with `@aa.grid_dec.to_array` / `@to_vector_yx` wrap their return value in an autoarray type. This wrapping is safe for intermediate calls (the autoarray object is consumed by downstream Python code). However, if such a function is the **outermost call** inside a `jax.jit` lambda, its return value will fail at the JIT boundary.

The solution is the **`if xp is np:` guard** in the function body:

```python
def convergence_2d_via_hessian_from(self, grid, xp=np):
convergence = 0.5 * (hessian_yy + hessian_xx)

if xp is np:
return aa.ArrayIrregular(values=convergence) # numpy: wrapped
return convergence # jax: raw jax.Array
```

This pattern is applied throughout `autogalaxy/operate/lens_calc.py`. Functions that are only ever called as intermediate steps (e.g. `deflections_yx_2d_from`) do NOT need this guard — their autoarray wrappers are never the JIT output.

### Linear Light Profiles & Inversions

`LightProfileLinear` subclasses do not take an `intensity` parameter—it is solved via a linear inversion (provided by `autoarray`). The `GalaxiesToInversion` class (`galaxy/to_inversion.py`) handles converting galaxies with linear profiles or pixelizations into the inversion objects needed by `autoarray`.
Expand Down Expand Up @@ -147,3 +195,22 @@ When importing `autogalaxy as ag`:
- `ag.ps.*` – point sources
- `ag.Galaxy`, `ag.Galaxies`
- `ag.FitImaging`, `ag.AnalysisImaging`, `ag.SimulatorImaging`

## Line Endings — Always Unix (LF)

All files in this project **must use Unix line endings (LF, `\n`)**. Windows/DOS line endings (CRLF, `\r\n`) will break Python files on HPC systems.

**When writing or editing any file**, always produce Unix line endings. Never write `\r\n` line endings.

After creating or copying files, verify and convert if needed:

```bash
# Check for DOS line endings
file autogalaxy/galaxy/galaxy.py # should say "ASCII text", not "CRLF"

# Convert all Python files in the project
find . -type f -name "*.py" | xargs dos2unix
```

Prefer simple shell commands.
Avoid chaining with && or pipes.
Comment on lines +215 to +216
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The guidance says to avoid chaining with pipes, but the example immediately above uses a pipe (find ... | xargs ...). Either adjust the recommendation (e.g. avoid complex pipelines) or update the example so the instructions are internally consistent.

Suggested change
Prefer simple shell commands.
Avoid chaining with && or pipes.
Prefer simple, readable shell commands.
Avoid complex command chains with multiple && operators or many pipes; a single straightforward pipe (as above) is fine when it improves clarity.

Copilot uses AI. Check for mistakes.
30 changes: 24 additions & 6 deletions autogalaxy/operate/lens_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,12 @@ def tangential_eigen_value_from(self, grid, xp=np) -> aa.Array2D:
convergence = self.convergence_2d_via_hessian_from(grid=grid, xp=xp)
shear_yx = self.shear_yx_2d_via_hessian_from(grid=grid, xp=xp)

return aa.Array2D(values=1 - convergence - shear_yx.magnitudes, mask=grid.mask)
if xp is np:
return aa.Array2D(
values=1 - convergence - shear_yx.magnitudes, mask=grid.mask
)
magnitudes = xp.sqrt(shear_yx[:, 0] ** 2 + shear_yx[:, 1] ** 2)
return 1 - convergence - magnitudes
Comment on lines +250 to +255
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is specifically to avoid TypeError at the jax.jit boundary, but there’s no regression test exercising these methods under jax.jit. Since JAX is already available in the test suite (test_autogalaxy/conftest.py imports jax.numpy), consider adding a small unit test that jits a lambda calling each updated method with xp=jnp and asserts it runs/returns a JAX array.

Copilot uses AI. Check for mistakes.
Comment on lines +250 to +255
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These methods now return a raw JAX array when xp is not np, but the return annotation still claims aa.Array2D. This mismatch can mislead callers and type checkers; consider updating the annotation to a Union[...] (or overloads) that reflects both the NumPy-wrapped and JAX-raw return types.

Copilot uses AI. Check for mistakes.

def radial_eigen_value_from(self, grid, xp=np) -> aa.Array2D:
"""
Expand All @@ -267,7 +272,12 @@ def radial_eigen_value_from(self, grid, xp=np) -> aa.Array2D:
convergence = self.convergence_2d_via_hessian_from(grid=grid, xp=xp)
shear = self.shear_yx_2d_via_hessian_from(grid=grid, xp=xp)

return aa.Array2D(values=1 - convergence + shear.magnitudes, mask=grid.mask)
if xp is np:
return aa.Array2D(
values=1 - convergence + shear.magnitudes, mask=grid.mask
)
magnitudes = xp.sqrt(shear[:, 0] ** 2 + shear[:, 1] ** 2)
return 1 - convergence + magnitudes
Comment on lines +275 to +280
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Return type annotation is aa.Array2D, but this function now returns a raw array on the JAX path (xp is not np). Update the annotation (e.g., Union[aa.Array2D, <array type>] / overloads) so the public API matches actual behavior.

Copilot uses AI. Check for mistakes.

def magnification_2d_from(self, grid, xp=np) -> aa.Array2D:
"""
Expand All @@ -288,7 +298,9 @@ def magnification_2d_from(self, grid, xp=np) -> aa.Array2D:

det_A = (1 - hessian_xx) * (1 - hessian_yy) - hessian_xy * hessian_yx

return aa.Array2D(values=1 / det_A, mask=grid.mask)
if xp is np:
return aa.Array2D(values=1 / det_A, mask=grid.mask)
return 1 / det_A
Comment on lines +301 to +303
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

magnification_2d_from is annotated to return aa.Array2D, but returns a raw array when xp is not np. Please update the return type annotation (and/or add overloads) to reflect the conditional wrapper behavior.

Copilot uses AI. Check for mistakes.

def deflections_yx_scalar(self, y, x, pixel_scales):
"""
Expand Down Expand Up @@ -465,7 +477,9 @@ def convergence_2d_via_hessian_from(self, grid, xp=np) -> aa.ArrayIrregular:

convergence = 0.5 * (hessian_yy + hessian_xx)

return aa.ArrayIrregular(values=convergence)
if xp is np:
return aa.ArrayIrregular(values=convergence)
return convergence
Comment on lines +480 to +482
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convergence_2d_via_hessian_from is annotated as returning aa.ArrayIrregular, but returns a raw array on the JAX path. Adjust the return type annotation (e.g., Union[aa.ArrayIrregular, <array type>] / overloads) so callers can rely on the signature.

Copilot uses AI. Check for mistakes.

def shear_yx_2d_via_hessian_from(self, grid, xp=np) -> ShearYX2DIrregular:
"""
Expand Down Expand Up @@ -506,7 +520,9 @@ def shear_yx_2d_via_hessian_from(self, grid, xp=np) -> ShearYX2DIrregular:

shear_yx_2d = xp.stack([gamma_2, gamma_1], axis=-1)

return ShearYX2DIrregular(values=shear_yx_2d, grid=grid)
if xp is np:
return ShearYX2DIrregular(values=shear_yx_2d, grid=grid)
return shear_yx_2d
Comment on lines +523 to +525
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shear_yx_2d_via_hessian_from is annotated to return ShearYX2DIrregular, but returns a raw (N, 2) array when xp is not np. Please update the return annotation (or add overloads) to reflect the JAX-compatible return type.

Copilot uses AI. Check for mistakes.

def magnification_2d_via_hessian_from(self, grid, xp=np) -> aa.ArrayIrregular:
"""
Expand Down Expand Up @@ -534,7 +550,9 @@ def magnification_2d_via_hessian_from(self, grid, xp=np) -> aa.ArrayIrregular:

det_A = (1 - hessian_xx) * (1 - hessian_yy) - hessian_xy * hessian_yx

return aa.ArrayIrregular(values=1.0 / det_A)
if xp is np:
return aa.ArrayIrregular(values=1.0 / det_A)
return 1.0 / det_A
Comment on lines +553 to +555
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

magnification_2d_via_hessian_from is annotated as returning aa.ArrayIrregular, but returns a raw array when xp is not np. Update the return type annotation (or add overloads) to match the conditional wrapper behavior.

Copilot uses AI. Check for mistakes.

def contour_list_from(self, grid, contour_array):
grid_contour = aa.Grid2DContour(
Expand Down
Loading