From 31e40f78d0e17506199ae69bce4d2e87ba3cc3d5 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Fri, 6 Mar 2026 16:05:12 +0000 Subject: [PATCH] fix JAX jit boundary in LensCalc + document decorator/JAX patterns All six hessian-derived LensCalc methods now guard autoarray wrapping with `if xp is np:` so they return a raw jax.Array on the JAX path. This allows them to be called directly inside jax.jit without the TypeError that occurred when an ArrayIrregular or Array2D was returned as the JIT output. Methods fixed: - convergence_2d_via_hessian_from - shear_yx_2d_via_hessian_from - magnification_2d_via_hessian_from - magnification_2d_from - tangential_eigen_value_from (also adds jnp.sqrt for shear magnitudes) - radial_eigen_value_from (same) CLAUDE.md updated with decorator system overview and the if-xp-is-np guard pattern for functions at the jax.jit boundary. Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 67 +++++++++++++++++++++++++++++++++ autogalaxy/operate/lens_calc.py | 30 ++++++++++++--- 2 files changed, 91 insertions(+), 6 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 327acf11..df35a9e9 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -78,6 +78,32 @@ Each dataset type has an `Analysis*` class that implements `log_likelihood_funct These inherit from `AnalysisDataset` → `Analysis` (in `analysis/analysis/`), which inherits `af.Analysis`. The `log_likelihood_function` builds a `Fit*` object from the `af.ModelInstance` and returns its `figure_of_merit`. +### Decorator System (from autoarray) + +Profile methods that consume a grid and return an array, grid, or vector use decorators from `autoarray.structures.decorators`. These ensure the **output type matches the input grid type**: + +| Decorator | `Grid2D` input | `Grid2DIrregular` input | +|---|---|---| +| `@aa.grid_dec.to_array` | `Array2D` | `ArrayIrregular` | +| `@aa.grid_dec.to_grid` | `Grid2D` | `Grid2DIrregular` | +| `@aa.grid_dec.to_vector_yx` | `VectorYX2D` | `VectorYX2DIrregular` | + +The `@aa.grid_dec.transform` decorator (always stacked below the output decorator) shifts and rotates the grid to the profile's reference frame before passing it to the function body. + +The canonical stacking order is: +```python +@aa.grid_dec.to_array # outermost: wraps output +@aa.grid_dec.transform # innermost: transforms grid +def convergence_2d_from(self, grid, xp=np, **kwargs): + y = grid.array[:, 0] # use .array to get raw numpy/jax array + x = grid.array[:, 1] + return ... # return raw array; decorator wraps it +``` + +**Key rule**: the function body must return a **raw array** (not an autoarray). The decorator handles wrapping. Access grid coordinates via `grid.array[:, 0]` / `grid.array[:, 1]` (not `grid[:, 0]`), because after `@transform` the grid is still an autoarray object and `.array` is the safe way to extract the underlying data for both numpy and jax backends. + +See PyAutoArray's `CLAUDE.md` for full details on the decorator internals. + ### JAX Support The codebase is designed so that **NumPy is the default everywhere and JAX is opt-in**. JAX is never imported at module level — it is only imported locally inside functions when explicitly requested. @@ -100,6 +126,28 @@ When adding a new function that should support JAX: 4. Add a JAX implementation in the guarded branch (e.g. `jax.jacfwd`, `jnp.vectorize`) 5. Verify correctness by comparing both paths in `autogalaxy_workspace_test/scripts/` +### JAX and autoarray wrappers at the `jax.jit` boundary + +Autoarray types (`Array2D`, `ArrayIrregular`, `VectorYX2DIrregular`, etc.) are **not registered as JAX pytrees**. This means: + +- Constructing them **inside** a JIT trace is fine (Python code runs normally during tracing) +- **Returning** them as the output of a `jax.jit`-compiled function **fails** with `TypeError: ... is not a valid JAX type` + +Functions decorated with `@aa.grid_dec.to_array` / `@to_vector_yx` wrap their return value in an autoarray type. This wrapping is safe for intermediate calls (the autoarray object is consumed by downstream Python code). However, if such a function is the **outermost call** inside a `jax.jit` lambda, its return value will fail at the JIT boundary. + +The solution is the **`if xp is np:` guard** in the function body: + +```python +def convergence_2d_via_hessian_from(self, grid, xp=np): + convergence = 0.5 * (hessian_yy + hessian_xx) + + if xp is np: + return aa.ArrayIrregular(values=convergence) # numpy: wrapped + return convergence # jax: raw jax.Array +``` + +This pattern is applied throughout `autogalaxy/operate/lens_calc.py`. Functions that are only ever called as intermediate steps (e.g. `deflections_yx_2d_from`) do NOT need this guard — their autoarray wrappers are never the JIT output. + ### Linear Light Profiles & Inversions `LightProfileLinear` subclasses do not take an `intensity` parameter—it is solved via a linear inversion (provided by `autoarray`). The `GalaxiesToInversion` class (`galaxy/to_inversion.py`) handles converting galaxies with linear profiles or pixelizations into the inversion objects needed by `autoarray`. @@ -147,3 +195,22 @@ When importing `autogalaxy as ag`: - `ag.ps.*` – point sources - `ag.Galaxy`, `ag.Galaxies` - `ag.FitImaging`, `ag.AnalysisImaging`, `ag.SimulatorImaging` + +## Line Endings — Always Unix (LF) + +All files in this project **must use Unix line endings (LF, `\n`)**. Windows/DOS line endings (CRLF, `\r\n`) will break Python files on HPC systems. + +**When writing or editing any file**, always produce Unix line endings. Never write `\r\n` line endings. + +After creating or copying files, verify and convert if needed: + +```bash +# Check for DOS line endings +file autogalaxy/galaxy/galaxy.py # should say "ASCII text", not "CRLF" + +# Convert all Python files in the project +find . -type f -name "*.py" | xargs dos2unix +``` + +Prefer simple shell commands. +Avoid chaining with && or pipes. \ No newline at end of file diff --git a/autogalaxy/operate/lens_calc.py b/autogalaxy/operate/lens_calc.py index 55055f23..b678315b 100644 --- a/autogalaxy/operate/lens_calc.py +++ b/autogalaxy/operate/lens_calc.py @@ -247,7 +247,12 @@ def tangential_eigen_value_from(self, grid, xp=np) -> aa.Array2D: convergence = self.convergence_2d_via_hessian_from(grid=grid, xp=xp) shear_yx = self.shear_yx_2d_via_hessian_from(grid=grid, xp=xp) - return aa.Array2D(values=1 - convergence - shear_yx.magnitudes, mask=grid.mask) + if xp is np: + return aa.Array2D( + values=1 - convergence - shear_yx.magnitudes, mask=grid.mask + ) + magnitudes = xp.sqrt(shear_yx[:, 0] ** 2 + shear_yx[:, 1] ** 2) + return 1 - convergence - magnitudes def radial_eigen_value_from(self, grid, xp=np) -> aa.Array2D: """ @@ -267,7 +272,12 @@ def radial_eigen_value_from(self, grid, xp=np) -> aa.Array2D: convergence = self.convergence_2d_via_hessian_from(grid=grid, xp=xp) shear = self.shear_yx_2d_via_hessian_from(grid=grid, xp=xp) - return aa.Array2D(values=1 - convergence + shear.magnitudes, mask=grid.mask) + if xp is np: + return aa.Array2D( + values=1 - convergence + shear.magnitudes, mask=grid.mask + ) + magnitudes = xp.sqrt(shear[:, 0] ** 2 + shear[:, 1] ** 2) + return 1 - convergence + magnitudes def magnification_2d_from(self, grid, xp=np) -> aa.Array2D: """ @@ -288,7 +298,9 @@ def magnification_2d_from(self, grid, xp=np) -> aa.Array2D: det_A = (1 - hessian_xx) * (1 - hessian_yy) - hessian_xy * hessian_yx - return aa.Array2D(values=1 / det_A, mask=grid.mask) + if xp is np: + return aa.Array2D(values=1 / det_A, mask=grid.mask) + return 1 / det_A def deflections_yx_scalar(self, y, x, pixel_scales): """ @@ -465,7 +477,9 @@ def convergence_2d_via_hessian_from(self, grid, xp=np) -> aa.ArrayIrregular: convergence = 0.5 * (hessian_yy + hessian_xx) - return aa.ArrayIrregular(values=convergence) + if xp is np: + return aa.ArrayIrregular(values=convergence) + return convergence def shear_yx_2d_via_hessian_from(self, grid, xp=np) -> ShearYX2DIrregular: """ @@ -506,7 +520,9 @@ def shear_yx_2d_via_hessian_from(self, grid, xp=np) -> ShearYX2DIrregular: shear_yx_2d = xp.stack([gamma_2, gamma_1], axis=-1) - return ShearYX2DIrregular(values=shear_yx_2d, grid=grid) + if xp is np: + return ShearYX2DIrregular(values=shear_yx_2d, grid=grid) + return shear_yx_2d def magnification_2d_via_hessian_from(self, grid, xp=np) -> aa.ArrayIrregular: """ @@ -534,7 +550,9 @@ def magnification_2d_via_hessian_from(self, grid, xp=np) -> aa.ArrayIrregular: det_A = (1 - hessian_xx) * (1 - hessian_yy) - hessian_xy * hessian_yx - return aa.ArrayIrregular(values=1.0 / det_A) + if xp is np: + return aa.ArrayIrregular(values=1.0 / det_A) + return 1.0 / det_A def contour_list_from(self, grid, contour_array): grid_contour = aa.Grid2DContour(