From a41d142471a8dc9d3a3a9a245fb4fadc7f613a13 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 1 Apr 2026 11:08:13 +0100 Subject: [PATCH 1/4] Add jax_zero_contour-based critical curve / caustic tracing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replaces the dense marching-squares grid evaluation with a JAX contour tracer that evaluates the eigen-value function only along the curve itself — far fewer function calls, no input grid required. New LensCalc methods - _make_eigen_fn(kind, pixel_scales): unified JAX scalar function factory (tangential/radial differ only in sign; hessian_xy is symmetrised) - _init_guess_from_coarse_grid: 25x25 seed search via marching squares - _critical_curve_list_via_zero_contour: shared private implementation - _caustic_list_via_zero_contour: shared private implementation - tangential/radial_critical_curve_list_via_zero_contour_from - tangential/radial_caustic_list_via_zero_contour_from - einstein_radius_list/scalar_via_zero_contour_from Visualizer config (visualize/general.yaml) - New key general.critical_curves_method (default: zero_contour) - plot_utils._critical_curves_method() reads config and dispatches - Both _critical_curves_from and _caustics_from respect the setting - Gracefully returns [] when no curve found (no ValueError in plot path) Also fixes pre-existing bug in galaxy_plots.subplot_of_mass_profiles where output filenames were prefixed with "subplot_" contrary to the intended behaviour described in commit 585acb66. Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 34 +- autogalaxy/config/visualize/general.yaml | 1 + autogalaxy/galaxy/plot/galaxy_plots.py | 2 +- autogalaxy/operate/lens_calc.py | 486 +++++++++++++++++- autogalaxy/plot/plot_utils.py | 83 ++- test_autogalaxy/plot/mat_wrap/test_visuals.py | 13 +- 6 files changed, 586 insertions(+), 33 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index df35a9e9..e037dbad 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -9,10 +9,10 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co pip install -e ".[dev]" ``` -### Run Tests -```bash -# All tests -python -m pytest test_autogalaxy/ +### Run Tests +```bash +# All tests +python -m pytest test_autogalaxy/ # Single test file python -m pytest test_autogalaxy/galaxy/test_galaxy.py @@ -20,13 +20,23 @@ python -m pytest test_autogalaxy/galaxy/test_galaxy.py # Single test python -m pytest test_autogalaxy/galaxy/test_galaxy.py::TestGalaxy::test_name -# With output -python -m pytest test_autogalaxy/imaging/test_fit_imaging.py -s -``` - -### Formatting -```bash -black autogalaxy/ +# With output +python -m pytest test_autogalaxy/imaging/test_fit_imaging.py -s +``` + +### Codex / sandboxed runs + +When running Python from Codex or any restricted environment, set writable cache directories so `numba` and `matplotlib` do not fail on unwritable home or source-tree paths: + +```bash +NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib python -m pytest test_autogalaxy/ +``` + +This workspace is often imported from `/mnt/c/...` and Codex may not be able to write to module `__pycache__` directories or `/home/jammy/.cache`, which can cause import-time `numba` caching failures without this override. + +### Formatting +```bash +black autogalaxy/ ``` ## Architecture @@ -213,4 +223,4 @@ find . -type f -name "*.py" | xargs dos2unix ``` Prefer simple shell commands. -Avoid chaining with && or pipes. \ No newline at end of file +Avoid chaining with && or pipes. diff --git a/autogalaxy/config/visualize/general.yaml b/autogalaxy/config/visualize/general.yaml index 51900f59..1fe89f21 100644 --- a/autogalaxy/config/visualize/general.yaml +++ b/autogalaxy/config/visualize/general.yaml @@ -2,6 +2,7 @@ general: backend: default imshow_origin: upper zoom_around_mask: true + critical_curves_method: zero_contour inversion: reconstruction_vmax_factor: 0.5 zoom: diff --git a/autogalaxy/galaxy/plot/galaxy_plots.py b/autogalaxy/galaxy/plot/galaxy_plots.py index 35f9b32d..191dd64f 100644 --- a/autogalaxy/galaxy/plot/galaxy_plots.py +++ b/autogalaxy/galaxy/plot/galaxy_plots.py @@ -143,4 +143,4 @@ def _deflections_x(mp): ) plt.tight_layout() - _save_subplot(fig, output_path, f"subplot_{name}", output_format) + _save_subplot(fig, output_path, name, output_format) diff --git a/autogalaxy/operate/lens_calc.py b/autogalaxy/operate/lens_calc.py index 4d630444..05ffc2f8 100644 --- a/autogalaxy/operate/lens_calc.py +++ b/autogalaxy/operate/lens_calc.py @@ -943,9 +943,491 @@ def einstein_mass_angular_from( if len(einstein_mass_angular_list) > 1: logger.info( """ - There are multiple tangential critical curves, and the computed Einstein mass is the sum of - all of them. Check the `einstein_mass_list_from` function for the individual Einstein. + There are multiple tangential critical curves, and the computed Einstein mass is the sum of + all of them. Check the `einstein_mass_list_from` function for the individual Einstein. """ ) return einstein_mass_angular_list[0] + + # ------------------------------------------------------------------------- + # jax_zero_contour-based critical curves / caustics + # ------------------------------------------------------------------------- + + def _make_eigen_fn(self, kind: str, pixel_scales=(0.05, 0.05)): + """Return a JAX scalar function ``f(pos) -> eigen_value``. + + ``pos`` has shape ``(2,)`` (y, x) — ``ZeroSolver.newton`` is vmapped + over the init_guess rows, passing each row slice individually to + ``jax.lax.custom_root(f, ...)``. + + The function is fully JAX-differentiable: ``ZeroSolver`` calls + ``jacfwd``/``jacrev`` on it internally (Newton's method on the eigen + value requires the second derivative of the deflections). + + ``hessian_xy`` is symmetrised as ``0.5 * (H[0,1] + H[1,0])`` to guard + against numerically non-curl-free deflection fields. + + Parameters + ---------- + kind + ``"tangential"`` (eigen value = ``1 - κ - |γ|``) or + ``"radial"`` (``1 - κ + |γ|``). + pixel_scales + Forwarded to ``deflections_yx_scalar`` for its internal + single-pixel ``Mask2D``. + """ + import jax + import jax.numpy as jnp + from jax.tree_util import Partial + + # Capture as local names so the closure holds no `self` reference. + # ZeroSolver.zero_contour_finder is jit-compiled with `f` as a + # non-static argument, so it must be a JAX pytree. Wrapping in + # Partial with no dynamic args gives a pytree whose treedef is the + # closure itself and whose leaves list is empty. + _deflections_yx_scalar = self.deflections_yx_scalar + _pixel_scales = pixel_scales + _sign = -1.0 if kind == "tangential" else 1.0 + + def _f(pos): + y, x = pos[0], pos[1] + H = jnp.stack( + jax.jacfwd(_deflections_yx_scalar, argnums=(0, 1))( + y, x, _pixel_scales + ) + ) + convergence = 0.5 * (H[0, 0] + H[1, 1]) + gamma_1 = 0.5 * (H[1, 1] - H[0, 0]) + gamma_2 = 0.5 * (H[0, 1] + H[1, 0]) # symmetrised + shear = jnp.sqrt(gamma_1 ** 2 + gamma_2 ** 2) + return 1.0 - convergence + _sign * shear + + return Partial(_f) + + def _make_tangential_eigen_fn(self, pixel_scales=(0.05, 0.05)): + """Return a JAX scalar function ``f(pos) -> tangential_eigen_value``.""" + return self._make_eigen_fn(kind="tangential", pixel_scales=pixel_scales) + + def _make_radial_eigen_fn(self, pixel_scales=(0.05, 0.05)): + """Return a JAX scalar function ``f(pos) -> radial_eigen_value``.""" + return self._make_eigen_fn(kind="radial", pixel_scales=pixel_scales) + + def _init_guess_from_coarse_grid( + self, + kind: str = "tangential", + grid_shape: Tuple[int, int] = (25, 25), + grid_extent: float = 3.0, + ): + """Return a rough initial-guess array near the critical curve. + + Evaluates the eigen values on a very coarse uniform grid (default + 25 × 25 = 625 evaluations, versus ~250 000 for the production grid) + and runs the existing marching-squares ``contour_list_from`` on it to + find seed points — one per distinct curve segment. The midpoint of + each coarse segment is taken as the initial guess. + + Parameters + ---------- + kind + ``"tangential"`` or ``"radial"``. + grid_shape + Number of pixels along each axis of the coarse evaluation grid. + grid_extent + Half-width of the coarse grid in arc-seconds. + + Returns + ------- + jax.numpy.ndarray of shape ``(n_curves, 2)`` + """ + import jax.numpy as jnp + + pixel_scale = 2.0 * grid_extent / grid_shape[0] + grid = aa.Grid2D.uniform( + shape_native=grid_shape, + pixel_scales=(pixel_scale, pixel_scale), + ) + + if kind == "tangential": + eigen_values = self.tangential_eigen_value_from(grid=grid) + else: + eigen_values = self.radial_eigen_value_from(grid=grid) + + coarse_curves = self.contour_list_from( + grid=grid, contour_array=eigen_values + ) + + if not coarse_curves: + raise ValueError( + f"No {kind} critical curve found within the coarse grid " + f"(extent ±{grid_extent} arcsec, shape {grid_shape}). " + "Pass an explicit `init_guess` or increase `grid_extent`." + ) + + seeds = [curve[len(curve) // 2] for curve in coarse_curves] + return jnp.array(seeds) + + def _critical_curve_list_via_zero_contour( + self, + kind: str, + init_guess=None, + delta: float = 0.05, + N: int = 500, + pixel_scales: Tuple[float, float] = (0.05, 0.05), + tol: float = 1e-6, + max_newton: int = 5, + ) -> List[aa.Grid2DIrregular]: + """Shared implementation for tangential/radial critical curves. + + Parameters + ---------- + kind + ``"tangential"`` or ``"radial"``. + init_guess + JAX or NumPy array of shape ``(n, 2)``. ``None`` triggers an + automatic coarse-grid seed search. + delta + Arc-second step size along the contour. + N + Maximum steps in each direction from each seed. + pixel_scales + Pixel scales passed to ``deflections_yx_scalar``. + tol + Newton's method convergence tolerance. + max_newton + Maximum Newton iterations per step. + """ + from jax_zero_contour import ZeroSolver + import jax.numpy as jnp + + if init_guess is None: + try: + init_guess = self._init_guess_from_coarse_grid(kind=kind) + except ValueError: + return [] + + init_guess = jnp.atleast_2d(jnp.asarray(init_guess)) + f = self._make_eigen_fn(kind=kind, pixel_scales=pixel_scales) + solver = ZeroSolver(tol=tol, max_newton=max_newton) + paths, _ = solver.zero_contour_finder(f, init_guess, delta=delta, N=N) + paths = ZeroSolver.path_reduce(paths) + + return [ + aa.Grid2DIrregular(values=np.array(path)) + for path in paths["path"] + if len(path) > 1 + ] + + def tangential_critical_curve_list_via_zero_contour_from( + self, + init_guess=None, + delta: float = 0.05, + N: int = 500, + pixel_scales: Tuple[float, float] = (0.05, 0.05), + tol: float = 1e-6, + max_newton: int = 5, + ) -> List[aa.Grid2DIrregular]: + """ + Returns tangential critical curves using the ``jax_zero_contour`` package. + + Unlike ``tangential_critical_curve_list_from``, this method does not + evaluate lensing quantities on a dense uniform grid. Instead it traces + the zero contour of the tangential eigen value directly, evaluating the + function only along the curve itself. + + The algorithm (from ``jax_zero_contour.ZeroSolver``): + + 1. Newton's method projects each initial guess onto the exact zero + contour of the tangential eigen value. + 2. Euler-Lagrange (gradient-perpendicular) stepping traces the contour + in both directions from each projected seed point. + 3. Tracing stops when the path closes, hits an endpoint, or exhausts + ``N`` steps. + + Parameters + ---------- + init_guess + JAX or NumPy array of shape ``(n, 2)`` with rough ``(y, x)`` + positions near the tangential critical curve — one seed per + distinct curve. If ``None`` a coarse 25 × 25 grid scan is used + to find seed points automatically. + delta + Arc-second step size along the contour. Smaller values give + denser, more accurate curves but require a larger ``N`` to trace + the same total length. + N + Maximum number of steps in each direction from each seed point. + The traced path has at most ``2N + 1`` points per seed. + pixel_scales + Pixel scales passed to ``deflections_yx_scalar`` for its internal + single-pixel mask. + tol + Newton's method convergence tolerance (forwarded to ``ZeroSolver``). + max_newton + Maximum Newton iterations per step (forwarded to ``ZeroSolver``). + + Returns + ------- + List[aa.Grid2DIrregular] + One ``Grid2DIrregular`` per traced contour segment, matching the + return type of ``tangential_critical_curve_list_from``. + """ + return self._critical_curve_list_via_zero_contour( + kind="tangential", + init_guess=init_guess, + delta=delta, + N=N, + pixel_scales=pixel_scales, + tol=tol, + max_newton=max_newton, + ) + + def radial_critical_curve_list_via_zero_contour_from( + self, + init_guess=None, + delta: float = 0.05, + N: int = 500, + pixel_scales: Tuple[float, float] = (0.05, 0.05), + tol: float = 1e-6, + max_newton: int = 5, + ) -> List[aa.Grid2DIrregular]: + """ + Returns radial critical curves using the ``jax_zero_contour`` package. + + Identical to ``tangential_critical_curve_list_via_zero_contour_from`` + except the zero contour of the *radial* eigen value is traced. + + Parameters + ---------- + init_guess + JAX or NumPy array of shape ``(n, 2)`` with rough ``(y, x)`` + positions near the radial critical curve. If ``None`` a coarse + grid scan finds seed points automatically. + delta + Arc-second step size along the contour. + N + Maximum number of steps in each direction from each seed. + pixel_scales + Pixel scales passed to ``deflections_yx_scalar``. + tol + Newton's method convergence tolerance. + max_newton + Maximum Newton iterations per step. + + Returns + ------- + List[aa.Grid2DIrregular] + """ + return self._critical_curve_list_via_zero_contour( + kind="radial", + init_guess=init_guess, + delta=delta, + N=N, + pixel_scales=pixel_scales, + tol=tol, + max_newton=max_newton, + ) + + def _caustic_list_via_zero_contour( + self, + kind: str, + init_guess=None, + delta: float = 0.05, + N: int = 500, + pixel_scales: Tuple[float, float] = (0.05, 0.05), + tol: float = 1e-6, + max_newton: int = 5, + ) -> List[aa.Grid2DIrregular]: + """Shared implementation for tangential/radial caustics.""" + cc_list = self._critical_curve_list_via_zero_contour( + kind=kind, + init_guess=init_guess, + delta=delta, + N=N, + pixel_scales=pixel_scales, + tol=tol, + max_newton=max_newton, + ) + return [ + cc - self.deflections_yx_2d_from(grid=cc) for cc in cc_list + ] + + def tangential_caustic_list_via_zero_contour_from( + self, + init_guess=None, + delta: float = 0.05, + N: int = 500, + pixel_scales: Tuple[float, float] = (0.05, 0.05), + tol: float = 1e-6, + max_newton: int = 5, + ) -> List[aa.Grid2DIrregular]: + """ + Returns tangential caustics by ray-tracing the tangential critical + curves computed via ``tangential_critical_curve_list_via_zero_contour_from``. + + Parameters + ---------- + init_guess + Forwarded to ``tangential_critical_curve_list_via_zero_contour_from``. + delta + Arc-second step size along the contour. + N + Maximum steps per seed direction. + pixel_scales + Pixel scales passed to ``deflections_yx_scalar``. + tol + Newton's method convergence tolerance. + max_newton + Maximum Newton iterations per step. + + Returns + ------- + List[aa.Grid2DIrregular] + """ + return self._caustic_list_via_zero_contour( + kind="tangential", + init_guess=init_guess, + delta=delta, + N=N, + pixel_scales=pixel_scales, + tol=tol, + max_newton=max_newton, + ) + + def radial_caustic_list_via_zero_contour_from( + self, + init_guess=None, + delta: float = 0.05, + N: int = 500, + pixel_scales: Tuple[float, float] = (0.05, 0.05), + tol: float = 1e-6, + max_newton: int = 5, + ) -> List[aa.Grid2DIrregular]: + """ + Returns radial caustics by ray-tracing the radial critical curves + computed via ``radial_critical_curve_list_via_zero_contour_from``. + + Parameters + ---------- + init_guess + Forwarded to ``radial_critical_curve_list_via_zero_contour_from``. + delta + Arc-second step size along the contour. + N + Maximum steps per seed direction. + pixel_scales + Pixel scales passed to ``deflections_yx_scalar``. + tol + Newton's method convergence tolerance. + max_newton + Maximum Newton iterations per step. + + Returns + ------- + List[aa.Grid2DIrregular] + """ + return self._caustic_list_via_zero_contour( + kind="radial", + init_guess=init_guess, + delta=delta, + N=N, + pixel_scales=pixel_scales, + tol=tol, + max_newton=max_newton, + ) + + def einstein_radius_list_via_zero_contour_from( + self, + init_guess=None, + delta: float = 0.05, + N: int = 500, + pixel_scales: Tuple[float, float] = (0.05, 0.05), + tol: float = 1e-6, + max_newton: int = 5, + ) -> List[float]: + """ + Returns a list of Einstein radii from the tangential critical curves + traced via ``tangential_critical_curve_list_via_zero_contour_from``. + + Each Einstein radius is the radius of the circle with the same area + as the corresponding tangential critical curve. + + Parameters + ---------- + init_guess + Forwarded to ``tangential_critical_curve_list_via_zero_contour_from``. + delta + Arc-second step size along the contour. + N + Maximum steps per seed direction. + pixel_scales + Pixel scales passed to ``deflections_yx_scalar``. + tol + Newton's method convergence tolerance. + max_newton + Maximum Newton iterations per step. + + Returns + ------- + List[float] + """ + tangential_critical_curve_list = ( + self.tangential_critical_curve_list_via_zero_contour_from( + init_guess=init_guess, + delta=delta, + N=N, + pixel_scales=pixel_scales, + tol=tol, + max_newton=max_newton, + ) + ) + area_list = self.area_within_curve_list_from( + curve_list=tangential_critical_curve_list + ) + return [np.sqrt(area / np.pi) for area in area_list] + + def einstein_radius_via_zero_contour_from( + self, + init_guess=None, + delta: float = 0.05, + N: int = 500, + pixel_scales: Tuple[float, float] = (0.05, 0.05), + tol: float = 1e-6, + max_newton: int = 5, + ) -> float: + """ + Returns the Einstein radius from the tangential critical curve traced + via ``jax_zero_contour``. + + If there are multiple tangential critical curves the radii are summed, + consistent with ``einstein_radius_from``. + + Parameters + ---------- + init_guess + Forwarded to ``einstein_radius_list_via_zero_contour_from``. + delta + Arc-second step size along the contour. + N + Maximum steps per seed direction. + pixel_scales + Pixel scales passed to ``deflections_yx_scalar``. + tol + Newton's method convergence tolerance. + max_newton + Maximum Newton iterations per step. + + Returns + ------- + float + """ + return sum( + self.einstein_radius_list_via_zero_contour_from( + init_guess=init_guess, + delta=delta, + N=N, + pixel_scales=pixel_scales, + tol=tol, + max_newton=max_newton, + ) + ) diff --git a/autogalaxy/plot/plot_utils.py b/autogalaxy/plot/plot_utils.py index 1c4a0083..e45a915a 100644 --- a/autogalaxy/plot/plot_utils.py +++ b/autogalaxy/plot/plot_utils.py @@ -270,16 +270,48 @@ def plot_grid( ) +def _critical_curves_method(): + """Read ``general.critical_curves_method`` from the visualize config. + + Returns ``"zero_contour"`` (the default) or ``"marching_squares"``. + Any unrecognised value falls back to ``"zero_contour"`` with a warning. + """ + from autoconf import conf + + try: + method = conf.instance["visualize"]["general"]["general"]["critical_curves_method"] + except (KeyError, TypeError): + method = "zero_contour" + + if method not in ("zero_contour", "marching_squares"): + logger.warning( + f"visualize/general.yaml: unrecognised critical_curves_method " + f"'{method}'. Falling back to 'zero_contour'." + ) + return "zero_contour" + return method + + def _caustics_from(mass_obj, grid): """Compute tangential and radial caustics for a mass object via LensCalc. + The algorithm used is controlled by ``general.critical_curves_method`` in + ``visualize/general.yaml``: + + - ``"zero_contour"`` *(default)* — uses ``jax_zero_contour`` to trace the + zero contour of each eigen value directly. No dense evaluation grid is + needed; a coarse 25 × 25 scan finds the seed points automatically. + - ``"marching_squares"`` — evaluates eigen values on the full *grid* and + uses marching squares to find the contours. + Parameters ---------- mass_obj Any object understood by ``LensCalc.from_mass_obj`` (e.g. a :class:`~autogalaxy.galaxy.galaxies.Galaxies` or autolens ``Tracer``). grid : aa.type.Grid2DLike - The grid on which to evaluate the caustics. + The grid on which to evaluate the caustics (used only for the + ``"marching_squares"`` path; ignored by ``"zero_contour"``). Returns ------- @@ -289,8 +321,15 @@ def _caustics_from(mass_obj, grid): from autogalaxy.operate.lens_calc import LensCalc od = LensCalc.from_mass_obj(mass_obj) - tan_ca = od.tangential_caustic_list_from(grid=grid) - rad_ca = od.radial_caustic_list_from(grid=grid) + method = _critical_curves_method() + + if method == "zero_contour": + tan_ca = od.tangential_caustic_list_via_zero_contour_from() + rad_ca = od.radial_caustic_list_via_zero_contour_from() + else: + tan_ca = od.tangential_caustic_list_from(grid=grid) + rad_ca = od.radial_caustic_list_from(grid=grid) + return tan_ca, rad_ca @@ -298,24 +337,28 @@ def _critical_curves_from(mass_obj, grid, tc=None, rc=None): """Compute tangential and radial critical curves for a mass object. If *tc* is already provided it is returned unchanged (along with *rc*), - allowing callers to cache the curves across multiple plot calls. When - *tc* is ``None`` the curves are computed via :class:`LensCalc`. Radial - critical curves are only computed when at least one radial critical-curve - area exceeds the grid pixel scale, avoiding spurious empty curves. + allowing callers to cache the curves across multiple plot calls. + + The algorithm used when *tc* is ``None`` is controlled by + ``general.critical_curves_method`` in ``visualize/general.yaml``: +/btw ok + - ``"zero_contour"`` *(default)* — uses ``jax_zero_contour``; no dense + grid needed, seed points found automatically via a coarse grid scan. + - ``"marching_squares"`` — evaluates eigen values on the full *grid* and + uses marching squares. Radial critical curves are only computed when at + least one radial critical-curve area exceeds the grid pixel scale. Parameters ---------- mass_obj - Any object understood by ``LensCalc.from_mass_obj`` (e.g. a - :class:`~autogalaxy.galaxy.galaxies.Galaxies` instance). + Any object understood by ``LensCalc.from_mass_obj``. grid : aa.type.Grid2DLike - The grid on which to evaluate the critical curves. + Evaluation grid (used only for the ``"marching_squares"`` path). tc : list or None - Pre-computed tangential critical curves. Pass ``None`` to trigger + Pre-computed tangential critical curves; ``None`` to trigger computation. rc : list or None - Pre-computed radial critical curves. Pass ``None`` to trigger - computation. + Pre-computed radial critical curves; ``None`` to trigger computation. Returns ------- @@ -326,9 +369,15 @@ def _critical_curves_from(mass_obj, grid, tc=None, rc=None): if tc is None: od = LensCalc.from_mass_obj(mass_obj) - tc = od.tangential_critical_curve_list_from(grid=grid) - rc_area = od.radial_critical_curve_area_list_from(grid=grid) - if any(area > grid.pixel_scale for area in rc_area): - rc = od.radial_critical_curve_list_from(grid=grid) + method = _critical_curves_method() + + if method == "zero_contour": + tc = od.tangential_critical_curve_list_via_zero_contour_from() + rc = od.radial_critical_curve_list_via_zero_contour_from() + else: + tc = od.tangential_critical_curve_list_from(grid=grid) + rc_area = od.radial_critical_curve_area_list_from(grid=grid) + if any(area > grid.pixel_scale for area in rc_area): + rc = od.radial_critical_curve_list_from(grid=grid) return tc, rc diff --git a/test_autogalaxy/plot/mat_wrap/test_visuals.py b/test_autogalaxy/plot/mat_wrap/test_visuals.py index f1f52922..780aea0f 100644 --- a/test_autogalaxy/plot/mat_wrap/test_visuals.py +++ b/test_autogalaxy/plot/mat_wrap/test_visuals.py @@ -39,14 +39,25 @@ def test__2d__caustics_from_mass_obj(gal_x1_mp, grid_2d_7x7): def test__mass_plotter__tangential_critical_curves(gal_x1_mp, grid_2d_7x7): + import numpy as np from autogalaxy.plot.plot_utils import _critical_curves_from tc, rc = _critical_curves_from(gal_x1_mp, grid_2d_7x7) + assert tc is not None + assert len(tc) > 0 + + # The default method is zero_contour, which traces the same locus as + # marching squares but with a different point density. Verify geometric + # consistency: mean radius must agree to within 5%. od = LensCalc.from_mass_obj(gal_x1_mp) expected_tc = od.tangential_critical_curve_list_from(grid=grid_2d_7x7) - assert (tc[0] == expected_tc[0]).all() + def _mean_radius(curve): + pts = np.array(curve) + return float(np.mean(np.sqrt(pts[:, 0] ** 2 + pts[:, 1] ** 2))) + + assert abs(_mean_radius(tc[0]) - _mean_radius(expected_tc[0])) / _mean_radius(expected_tc[0]) < 0.05 def test__mass_plotter__caustics_via_lens_calc(gal_x1_mp, grid_2d_7x7): From f07523aaf89a705878349668527e68458b1aae5a Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 1 Apr 2026 11:19:00 +0100 Subject: [PATCH 2/4] Make jax_zero_contour an optional dependency with a clear error message Add jax_zero_contour to pyproject.toml [optional] extras and wrap the import inside _critical_curve_list_via_zero_contour in a try/except so that the rest of the package imports cleanly on build servers that do not have the package installed. Co-Authored-By: Claude Sonnet 4.6 --- autogalaxy/operate/lens_calc.py | 8 +++++++- pyproject.toml | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/autogalaxy/operate/lens_calc.py b/autogalaxy/operate/lens_calc.py index 05ffc2f8..fcddbfb1 100644 --- a/autogalaxy/operate/lens_calc.py +++ b/autogalaxy/operate/lens_calc.py @@ -1097,7 +1097,13 @@ def _critical_curve_list_via_zero_contour( max_newton Maximum Newton iterations per step. """ - from jax_zero_contour import ZeroSolver + try: + from jax_zero_contour import ZeroSolver + except ModuleNotFoundError as exc: + raise ModuleNotFoundError( + "jax_zero_contour is required for zero-contour critical curve tracing. " + "Install it with: pip install jax_zero_contour" + ) from exc import jax.numpy as jnp if init_guess is None: diff --git a/pyproject.toml b/pyproject.toml index a4412111..0dd5268c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,7 @@ local_scheme = "no-local-version" [project.optional-dependencies] optional=[ "numba", + "jax_zero_contour", "pynufft", "ultranest==3.6.2", "zeus-mcmc==2.5.4", From 37db7749ad9c44e62f74f6bee1876c48360a80a0 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 1 Apr 2026 11:25:34 +0100 Subject: [PATCH 3/4] Close zero-contour critical curves by appending first point to last Appending pts[0] to the end of each path ensures the plotted critical curve and its derived caustic have no gap between the last and first point. Co-Authored-By: Claude Sonnet 4.6 --- autogalaxy/operate/lens_calc.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/autogalaxy/operate/lens_calc.py b/autogalaxy/operate/lens_calc.py index fcddbfb1..c05775e8 100644 --- a/autogalaxy/operate/lens_calc.py +++ b/autogalaxy/operate/lens_calc.py @@ -1118,11 +1118,13 @@ def _critical_curve_list_via_zero_contour( paths, _ = solver.zero_contour_finder(f, init_guess, delta=delta, N=N) paths = ZeroSolver.path_reduce(paths) - return [ - aa.Grid2DIrregular(values=np.array(path)) - for path in paths["path"] - if len(path) > 1 - ] + closed_paths = [] + for path in paths["path"]: + if len(path) > 1: + pts = np.array(path) + pts = np.vstack([pts, pts[0]]) # close the curve + closed_paths.append(aa.Grid2DIrregular(values=pts)) + return closed_paths def tangential_critical_curve_list_via_zero_contour_from( self, From 768d3575b9707ad63c9de1d885436227afc38e3c Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 1 Apr 2026 11:47:45 +0100 Subject: [PATCH 4/4] Add Jax-Zero-Contour citation to docs/general/citations.rst Co-Authored-By: Claude Sonnet 4.6 --- docs/general/citations.rst | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/docs/general/citations.rst b/docs/general/citations.rst index 27df530f..4aaa3222 100644 --- a/docs/general/citations.rst +++ b/docs/general/citations.rst @@ -18,6 +18,29 @@ You should also specify the non-linear search(es) you use in your analysis (e.g. the main body of text, and delete as appropriate any packages your analysis did not use. The citations.bib file includes the citation key for all of these projects. +Jax-Zero-Contour +---------------- + +If you use the zero-contour method for critical curve and caustic computation (the default in +``visualize/general.yaml`` via ``critical_curves_method: zero_contour``), please cite the +``Jax-Zero-Contour`` package by Coleman Krawczyk: + +.. code-block:: bibtex + + @software{coleman_krawczyk_2025_15730415, + author = {Coleman Krawczyk}, + title = {CKrawczyk/Jax-Zero-Contour: Version 2.0.0}, + month = jun, + year = 2025, + publisher = {Zenodo}, + version = {v2.0.0}, + doi = {10.5281/zenodo.15730415}, + url = {https://doi.org/10.5281/zenodo.15730415}, + } + +The package is available at https://github.com/CKrawczyk/Jax-Zero-Contour and archived at +https://doi.org/10.5281/zenodo.15730415. + Dynesty -------