Conversation
Replaces the finite-difference hessian_from with a dual-path implementation: - xp=np (default): delegates to _hessian_via_finite_difference, no JAX import - xp=jnp: delegates to _hessian_via_jax which uses jax.jacfwd on a new deflections_yx_scalar helper, supporting both Grid2D and Grid2DIrregular Also fixes shear_yx_2d_via_hessian_from to use grid.shape[0] instead of grid.shape_slim (incompatible with Grid2DIrregular), and makes magnification_2d_via_hessian_from return a raw jax.Array when xp=jnp to avoid wrapping a traced value in ArrayIrregular. Updates CLAUDE.md to document the NumPy-default / JAX-opt-in design pattern. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Remove precompute_jacobian decorator, jacobian_from (np.gradient), convergence_2d_via_jacobian_from, and shear_yx_2d_via_jacobian_from. All were redundant with the hessian path since A = I - H. Rewire tangential_eigen_value_from, radial_eigen_value_from, and magnification_2d_from to call hessian_from directly. Restore jacobian_from as a thin public wrapper over hessian_from that returns [[1-hxx, -hxy], [-hyx, 1-hyy]], supporting both xp=np and xp=jnp. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
There was a problem hiding this comment.
Pull request overview
This pull request refactors OperateDeflections to make Hessian-based derivatives the primary path for lensing quantities, introduces a JAX auto-diff Hessian path selected via an xp backend parameter, and removes legacy Jacobian-precompute plumbing. It also updates repository guidance around JAX usage and workspace scripting style, and adjusts tests to match the new API.
Changes:
- Replaced Jacobian-based convergence/shear computations with Hessian-based equivalents; removed
precompute_jacobianand legacy Jacobian-via-gradient methods. - Added a JAX Hessian computation path using
jax.jacfwd, selected viaxp, and introducedjacobian_fromcomputed asA = I - H. - Updated docs (
CLAUDE.md) about the NumPy-default / JAX-opt-inxppattern and workspace script prose style; updated/trimmed deflections tests accordingly.
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 7 comments.
| File | Description |
|---|---|
| autogalaxy/operate/deflections.py | Refactors Hessian/Jacobian computation, adds JAX Hessian path via xp, removes legacy Jacobian utilities. |
| test_autogalaxy/operate/test_deflections.py | Updates tests to reflect Hessian-first API and revised Jacobian behavior. |
| CLAUDE.md | Clarifies JAX opt-in conventions (xp pattern) and workspace script commenting style. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
autogalaxy/operate/deflections.py
Outdated
| hessian_yy, hessian_xy, hessian_yx, hessian_xx = self.hessian_from(grid=grid) | ||
|
|
||
| return aa.ArrayIrregular(values=0.5 * (hessian_yy + hessian_xx)) |
There was a problem hiding this comment.
convergence_2d_via_hessian_from no longer accepts an xp parameter and always calls hessian_from with its default xp=np. This prevents convergence from being computed via the JAX Hessian path, which is inconsistent with hessian_from(..., xp=...) / magnification_2d_via_hessian_from(..., xp=...). Consider adding xp=np here and passing it through to hessian_from.
autogalaxy/operate/deflections.py
Outdated
| hessian_yy, hessian_xy, hessian_yx, hessian_xx = self.hessian_from(grid=grid) | ||
|
|
||
| gamma_1 = 0.5 * (hessian_xx - hessian_yy) | ||
| gamma_2 = hessian_xy | ||
|
|
||
| shear_yx_2d = np.zeros(shape=(grid.shape_slim, 2)) | ||
| shear_yx_2d = np.zeros(shape=(grid.shape[0], 2)) | ||
|
|
There was a problem hiding this comment.
shear_yx_2d_via_hessian_from always uses NumPy (np.zeros) and calls hessian_from without passing xp, so it cannot use the JAX derivative path even if the caller is operating with xp=jnp. If JAX support is intended here, add an xp=np parameter, pass it through to hessian_from, and allocate the output with xp.zeros (or otherwise ensure JAX-compatible arrays).
autogalaxy/operate/deflections.py
Outdated
| def shear_yx_2d_via_hessian_from( | ||
| self, grid, buffer: float = 0.01 | ||
| self, grid | ||
| ) -> ShearYX2DIrregular: | ||
| """ |
There was a problem hiding this comment.
The docstring for shear_yx_2d_via_hessian_from describes returning a ShearYX2D structure, but the function is annotated and implemented to return ShearYX2DIrregular. Update the docstring (and/or return type) so the documented API matches the implementation.
autogalaxy/operate/deflections.py
Outdated
| if xp is not np: | ||
| return xp.array(1.0 / det_A) |
There was a problem hiding this comment.
magnification_2d_via_hessian_from returns different types depending on xp: an aa.ArrayIrregular for NumPy but a raw xp.array for JAX. This makes the method’s return type inconsistent and contradicts the -> aa.ArrayIrregular annotation. Either always wrap in the same AutoArray structure (with JAX values inside), or update the annotation / docstring to reflect a union return type and ensure downstream callers handle both.
| if xp is not np: | |
| return xp.array(1.0 / det_A) |
autogalaxy/operate/deflections.py
Outdated
| @@ -193,105 +180,153 @@ def tangential_eigen_value_from(self, grid, jacobian=None) -> aa.Array2D: | |||
| grid | |||
| The 2D grid of (y,x) arc-second coordinates the deflection angles and tangential eigen values are computed | |||
| on. | |||
| jacobian | |||
| A precomputed lensing jacobian, which is passed throughout the `CalcLens` functions for efficiency. | |||
| """ | |||
| convergence = self.convergence_2d_via_jacobian_from( | |||
| grid=grid, jacobian=jacobian | |||
| ) | |||
|
|
|||
| shear_yx = self.shear_yx_2d_via_jacobian_from(grid=grid, jacobian=jacobian) | |||
| convergence = self.convergence_2d_via_hessian_from(grid=grid) | |||
| shear_yx = self.shear_yx_2d_via_hessian_from(grid=grid) | |||
|
|
|||
| return aa.Array2D(values=1 - convergence - shear_yx.magnitudes, mask=grid.mask) | |||
There was a problem hiding this comment.
These eigenvalue helpers no longer accept xp, so they always go through the NumPy path (via convergence_2d_via_hessian_from / shear_yx_2d_via_hessian_from). If JAX support via the xp pattern is intended for eigenvalue / critical-curve calculations, consider adding xp=np here and passing it through to the underlying convergence/shear (and ultimately hessian_from).
autogalaxy/operate/deflections.py
Outdated
| def radial_eigen_value_from(self, grid) -> aa.Array2D: | ||
| """ | ||
| Returns the radial eigen values of lensing jacobian, which are given by the expression: | ||
|
|
||
| radial_eigen_value = 1 - convergence + shear | ||
| `radial_eigen_value = 1 - convergence + shear` | ||
|
|
||
| Parameters | ||
| ---------- | ||
| grid | ||
| The 2D grid of (y,x) arc-second coordinates the deflection angles and radial eigen values are computed on. | ||
| jacobian | ||
| A precomputed lensing jacobian, which is passed throughout the `CalcLens` functions for efficiency. | ||
| """ | ||
| convergence = self.convergence_2d_via_jacobian_from( | ||
| grid=grid, jacobian=jacobian | ||
| ) | ||
|
|
||
| shear = self.shear_yx_2d_via_jacobian_from(grid=grid, jacobian=jacobian) | ||
| convergence = self.convergence_2d_via_hessian_from(grid=grid) | ||
| shear = self.shear_yx_2d_via_hessian_from(grid=grid) | ||
|
|
||
| return aa.Array2D(values=1 - convergence + shear.magnitudes, mask=grid.mask) |
There was a problem hiding this comment.
radial_eigen_value_from also lacks an xp parameter and therefore can’t take advantage of the JAX Hessian path. If the intention of this refactor is that Hessian-based calculations are JAX-acceleratable via xp, it would be more consistent to thread xp through here as well.
| def _hessian_single(y_scalar, x_scalar): | ||
| return jnp.stack( | ||
| jax.jacfwd(self.deflections_yx_scalar, argnums=(0, 1))( | ||
| y_scalar, x_scalar, pixel_scales | ||
| ) | ||
| ) | ||
|
|
||
| h = jnp.vectorize(_hessian_single, signature="(),()->(i,i)")(y, x) |
There was a problem hiding this comment.
The JAX Hessian path uses jnp.vectorize over per-point jax.jacfwd calls. In JAX this is typically a Python-level loop and can be very slow / inhibit jit performance. Prefer a jax.vmap-based implementation (e.g., vmap a jacfwd’d scalar function, or jacfwd a batched function) to get compiled vectorization.
| def _hessian_single(y_scalar, x_scalar): | |
| return jnp.stack( | |
| jax.jacfwd(self.deflections_yx_scalar, argnums=(0, 1))( | |
| y_scalar, x_scalar, pixel_scales | |
| ) | |
| ) | |
| h = jnp.vectorize(_hessian_single, signature="(),()->(i,i)")(y, x) | |
| jac_fn = jax.jacfwd(self.deflections_yx_scalar, argnums=(0, 1)) | |
| def _hessian_single(y_scalar, x_scalar): | |
| return jnp.stack(jac_fn(y_scalar, x_scalar, pixel_scales)) | |
| # Use jax.vmap for efficient batched evaluation over the grid | |
| h = jax.vmap(_hessian_single, in_axes=(0, 0))(y, x) |
…_2d hessian methods Thread xp=np through convergence_2d_via_hessian_from, shear_yx_2d_via_hessian_from, tangential_eigen_value_from, radial_eigen_value_from, and magnification_2d_from so all hessian-derived quantities support the JAX path consistently. For each method, xp is passed through to hessian_from; when xp is not numpy the result is returned as a raw jax.Array rather than an autoarray wrapper (which cannot be constructed during a jax.jit trace). Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
…o LensCalc - Rename operate/deflections.py -> operate/lens_calc.py - Rename class OperateDeflections -> LensCalc throughout - LensCalc.__init__ now accepts optional potential_2d_from callable - from_mass_obj and from_tracer capture potential_2d_from automatically - fermat_potential_from moved into LensCalc (removed from MassProfile, Galaxy, Galaxies) - Update all imports, call sites, tests, and docs in autogalaxy Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This pull request refactors the lensing Jacobian and Hessian computation logic in
autogalaxy/operate/deflections.pyto make Hessian-based calculations the default, unify the API for NumPy and JAX backends, and remove legacy Jacobian-based methods. It also updates documentation to clarify JAX usage and workspace script style. The changes improve flexibility, performance, and code clarity, especially for JAX integration.Lensing Jacobian & Hessian Refactor
precompute_jacobiandecorator. (autogalaxy/operate/deflections.py) [1] [2] [3] [4]xpparameter. JAX is now only imported locally when requested, and Hessian computation uses either finite-difference (NumPy) or auto-differentiation (jax.jacfwd) for JAX. (autogalaxy/operate/deflections.py) [1] [2] [3] [4]JAX Integration & Documentation
CLAUDE.mdto clarify that NumPy is the default backend and JAX is opt-in, imported only within functions when needed. Provided clear instructions for adding JAX support to new functions and described the testing strategy for JAX vs. NumPy paths. (CLAUDE.md)CLAUDE.md)Miscellaneous
autogalaxy/operate/deflections.py)test_autogalaxy/operate/test_deflections.py)