From 632ae8292bf4deddaa58e45dba580a67c3826646 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Sun, 12 Apr 2026 15:58:08 +0100 Subject: [PATCH] build: raise astropy cap, update JAX Partial imports for 0.5+ compat Raise astropy cap from <=6.1.2 to <=7.2.0. Replace deprecated jax.tree_util.Partial with functools.partial in jax_utils.py, and add fallback import in lens_calc.py for forward compatibility. Part of cross-ecosystem dependency sweep (PyAutoLabs/PyAutoConf#87). Co-Authored-By: Claude Opus 4.6 --- autogalaxy/operate/lens_calc.py | 5 ++++- autogalaxy/profiles/mass/total/jax_utils.py | 2 +- pyproject.toml | 2 +- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/autogalaxy/operate/lens_calc.py b/autogalaxy/operate/lens_calc.py index 50d90a91d..bb839e29a 100644 --- a/autogalaxy/operate/lens_calc.py +++ b/autogalaxy/operate/lens_calc.py @@ -983,7 +983,10 @@ def _make_eigen_fn(self, kind: str, pixel_scales=(0.05, 0.05)): """ import jax import jax.numpy as jnp - from jax.tree_util import Partial + try: + from jax.tree_util import Partial + except ImportError: + from functools import partial as Partial # Capture as local names so the closure holds no `self` reference. # ZeroSolver.zero_contour_finder is jit-compiled with `f` as a diff --git a/autogalaxy/profiles/mass/total/jax_utils.py b/autogalaxy/profiles/mass/total/jax_utils.py index d04691ea2..94c2f1399 100644 --- a/autogalaxy/profiles/mass/total/jax_utils.py +++ b/autogalaxy/profiles/mass/total/jax_utils.py @@ -32,7 +32,7 @@ def omega(eiphi, slope, factor, n_terms=20, xp=np): be sufficient most of the time) """ - from jax.tree_util import Partial as partial + from functools import partial import jax scan = jax.jit(jax.lax.scan, static_argnames=("length", "reverse", "unroll")) diff --git a/pyproject.toml b/pyproject.toml index 1bac85271..dcfc8dc4b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "autofit", "autoarray", "colossus==1.3.1", - "astropy>=5.0,<=6.1.2", + "astropy>=5.0,<=7.2.0", "nautilus-sampler==1.0.5" ]