diff --git a/autogalaxy/cosmology/model.py b/autogalaxy/cosmology/model.py index 98e13847..f3be5dfa 100644 --- a/autogalaxy/cosmology/model.py +++ b/autogalaxy/cosmology/model.py @@ -574,15 +574,18 @@ def angular_diameter_distance_kpc_z1z2( - Dark energy equation of state constant w0 Results are cached keyed on (z1, z2, n_steps) to avoid redundant - Simpson integrations for repeated redshift pairs. - """ - cache = getattr(self, "_dist_cache", None) - if cache is None: - cache = {} - self._dist_cache = cache - key = (float(z1), float(z2), n_steps) - if key in cache: - return xp.asarray(cache[key]) + Simpson integrations for repeated redshift pairs. Cache is only used + on the NumPy path — JAX abstract tracers cannot be converted to float + and JAX handles its own caching via JIT compilation. + """ + if xp is np: + cache = getattr(self, "_dist_cache", None) + if cache is None: + cache = {} + self._dist_cache = cache + key = (float(z1), float(z2), n_steps) + if key in cache: + return cache[key] # Ensure odd number of samples for Simpson (safe: n_steps is a Python int) if (n_steps % 2) == 0: @@ -650,7 +653,8 @@ def E_local(z): result = xp.where(same, xp.asarray(0.0), Da_kpc) - self._dist_cache[key] = float(result) + if xp is np: + self._dist_cache[key] = result return result