diff --git a/autogalaxy/operate/lens_calc.py b/autogalaxy/operate/lens_calc.py index 50d90a91..bb839e29 100644 --- a/autogalaxy/operate/lens_calc.py +++ b/autogalaxy/operate/lens_calc.py @@ -983,7 +983,10 @@ def _make_eigen_fn(self, kind: str, pixel_scales=(0.05, 0.05)): """ import jax import jax.numpy as jnp - from jax.tree_util import Partial + try: + from jax.tree_util import Partial + except ImportError: + from functools import partial as Partial # Capture as local names so the closure holds no `self` reference. # ZeroSolver.zero_contour_finder is jit-compiled with `f` as a diff --git a/autogalaxy/profiles/mass/total/jax_utils.py b/autogalaxy/profiles/mass/total/jax_utils.py index d04691ea..94c2f139 100644 --- a/autogalaxy/profiles/mass/total/jax_utils.py +++ b/autogalaxy/profiles/mass/total/jax_utils.py @@ -32,7 +32,7 @@ def omega(eiphi, slope, factor, n_terms=20, xp=np): be sufficient most of the time) """ - from jax.tree_util import Partial as partial + from functools import partial import jax scan = jax.jit(jax.lax.scan, static_argnames=("length", "reverse", "unroll")) diff --git a/pyproject.toml b/pyproject.toml index 1bac8527..dcfc8dc4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ dependencies = [ "autofit", "autoarray", "colossus==1.3.1", - "astropy>=5.0,<=6.1.2", + "astropy>=5.0,<=7.2.0", "nautilus-sampler==1.0.5" ]