Skip to content

Fix cosmology distance cache breaking JAX JIT tracing#334

Merged
Jammy2211 merged 1 commit intomainfrom
feature/fix-cosmology-jax-cache
Apr 7, 2026
Merged

Fix cosmology distance cache breaking JAX JIT tracing#334
Jammy2211 merged 1 commit intomainfrom
feature/fix-cosmology-jax-cache

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

  • Guards the angular_diameter_distance_kpc_z1z2 distance cache with if xp is np: so it is skipped entirely on the JAX path
  • Previously float(result) was called on a JAX abstract tracer during JIT compilation, raising ConcretizationTypeError
  • Also removed xp.asarray() wrapping on cache hit return (result is already a numpy array)
  • JAX handles its own caching via JIT compilation so no caching is needed on that path

Test plan

  • Run PyAutoGalaxy tests: python -m pytest test_autogalaxy/
  • Verify real JAX fitting no longer raises ConcretizationTypeError in cosmology
  • Verify numpy path still caches correctly (repeated calls return same object)

🤖 Generated with Claude Code

The dist cache used float(result) to store values, which fails during
JAX JIT compilation because abstract tracers cannot be converted to
float. Guard both cache lookup and store with `if xp is np:` so the
cache is only used on the NumPy path — JAX handles its own caching
via JIT compilation.

Co-Authored-By: Claude Haiku 4.5 <noreply@anthropic.com>
@Jammy2211 Jammy2211 merged commit da64c5d into main Apr 7, 2026
2 checks passed
@Jammy2211 Jammy2211 deleted the feature/fix-cosmology-jax-cache branch April 7, 2026 19:09
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant