From fcf875e0c3553c98263a981e324b9e97931c7dc9 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Wed, 1 Apr 2026 19:28:01 +0100 Subject: [PATCH] Fix JAX array leak from PointSolver.solve() into Grid2DIrregular When the solver uses a JAX backend, the final boolean-indexed solution was a JAX DeviceArray. Wrapping it in Grid2DIrregular without conversion caused downstream np.array() calls in the visualizer to raise: ValueError: object __array__ method not producing an array Fix: convert solution to numpy via np.asarray() before constructing the return Grid2DIrregular. Safe because solve() is never called inside jax.jit (variable-length boolean indexing prevents it). Also updates the Returns docstring to document the numpy-backed guarantee. Co-Authored-By: Claude Sonnet 4.6 --- CLAUDE.md | 32 ++++++++++++++++++--------- autolens/point/solver/point_solver.py | 6 +++-- 2 files changed, 25 insertions(+), 13 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index 9a93253ec..a1e338899 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -9,21 +9,31 @@ This file provides guidance to Claude Code (claude.ai/code) when working with co pip install -e ".[dev]" ``` -### Run Tests -```bash -# All tests -python -m pytest test_autolens/ +### Run Tests +```bash +# All tests +python -m pytest test_autolens/ # Single test file python -m pytest test_autolens/lens/test_tracer.py -# With output -python -m pytest test_autolens/imaging/test_fit_imaging.py -s -``` - -### Formatting -```bash -black autolens/ +# With output +python -m pytest test_autolens/imaging/test_fit_imaging.py -s +``` + +### Codex / sandboxed runs + +When running Python from Codex or any restricted environment, set writable cache directories so `numba` and `matplotlib` do not fail on unwritable home or source-tree paths: + +```bash +NUMBA_CACHE_DIR=/tmp/numba_cache MPLCONFIGDIR=/tmp/matplotlib python -m pytest test_autolens/ +``` + +This workspace is often imported from `/mnt/c/...` and Codex may not be able to write to module `__pycache__` directories or `/home/jammy/.cache`, which can cause import-time `numba` caching failures without this override. + +### Formatting +```bash +black autolens/ ``` ## Architecture diff --git a/autolens/point/solver/point_solver.py b/autolens/point/solver/point_solver.py index 6a5cad4b6..79d94f4f6 100644 --- a/autolens/point/solver/point_solver.py +++ b/autolens/point/solver/point_solver.py @@ -19,6 +19,7 @@ import logging from typing import Tuple, Optional +import numpy as np import autoarray as aa from autoarray.structures.triangles.shape import Point @@ -66,7 +67,8 @@ def solve( Returns ------- - A list of image plane coordinates that are traced to the source plane coordinate. + A ``Grid2DIrregular`` of image-plane coordinates, always numpy-backed even when the + solver uses a JAX backend internally. """ kept_triangles = super().solve_triangles( tracer=tracer, @@ -90,4 +92,4 @@ def solve( solution = solution[~self._xp.isinf(solution).any(axis=1)] - return aa.Grid2DIrregular(solution) + return aa.Grid2DIrregular(np.asarray(solution))