From c70fe32a27bb9d4ce08cb97a69551240c41369f0 Mon Sep 17 00:00:00 2001 From: Jammy2211 Date: Tue, 7 Apr 2026 19:47:39 +0100 Subject: [PATCH] Fix cosmology cache breaking JAX JIT tracing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The dist cache used float(result) to store values, which fails during JAX JIT compilation because abstract tracers cannot be converted to float. Guard both cache lookup and store with `if xp is np:` so the cache is only used on the NumPy path — JAX handles its own caching via JIT compilation. Co-Authored-By: Claude Haiku 4.5 --- autogalaxy/cosmology/model.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/autogalaxy/cosmology/model.py b/autogalaxy/cosmology/model.py index 98e13847..f3be5dfa 100644 --- a/autogalaxy/cosmology/model.py +++ b/autogalaxy/cosmology/model.py @@ -574,15 +574,18 @@ def angular_diameter_distance_kpc_z1z2( - Dark energy equation of state constant w0 Results are cached keyed on (z1, z2, n_steps) to avoid redundant - Simpson integrations for repeated redshift pairs. - """ - cache = getattr(self, "_dist_cache", None) - if cache is None: - cache = {} - self._dist_cache = cache - key = (float(z1), float(z2), n_steps) - if key in cache: - return xp.asarray(cache[key]) + Simpson integrations for repeated redshift pairs. Cache is only used + on the NumPy path — JAX abstract tracers cannot be converted to float + and JAX handles its own caching via JIT compilation. + """ + if xp is np: + cache = getattr(self, "_dist_cache", None) + if cache is None: + cache = {} + self._dist_cache = cache + key = (float(z1), float(z2), n_steps) + if key in cache: + return cache[key] # Ensure odd number of samples for Simpson (safe: n_steps is a Python int) if (n_steps % 2) == 0: @@ -650,7 +653,8 @@ def E_local(z): result = xp.where(same, xp.asarray(0.0), Da_kpc) - self._dist_cache[key] = float(result) + if xp is np: + self._dist_cache[key] = result return result