Skip to content

Fix JAX jit boundary in LensCalc + document decorator/JAX patterns#291

Merged
Jammy2211 merged 1 commit intomainfrom
feature/jax_decorator_bypass
Mar 6, 2026
Merged

Fix JAX jit boundary in LensCalc + document decorator/JAX patterns#291
Jammy2211 merged 1 commit intomainfrom
feature/jax_decorator_bypass

Conversation

@Jammy2211
Copy link
Copy Markdown
Collaborator

Summary

  • Fixes a TypeError that occurred when any of the six hessian-derived LensCalc methods were called directly inside jax.jit
  • Updates CLAUDE.md with decorator system overview and the if xp is np: guard pattern

Problem

Autoarray types (ArrayIrregular, ShearYX2DIrregular, Array2D) are not registered as JAX pytrees. The following methods unconditionally wrapped their return values in these types, causing TypeError: ... is not a valid JAX type when used as the output of a jax.jit-compiled function:

  • convergence_2d_via_hessian_from
  • shear_yx_2d_via_hessian_from
  • magnification_2d_via_hessian_from
  • magnification_2d_from
  • tangential_eigen_value_from
  • radial_eigen_value_from

Fix

Each method now guards autoarray wrapping with if xp is np:, returning the wrapper on the numpy path and a raw jax.Array on the JAX path. tangential_eigen_value_from and radial_eigen_value_from also compute shear magnitudes via xp.sqrt on the JAX path (since ShearYX2DIrregular.magnitudes uses np.sqrt directly and is not usable on a raw jax array).

🤖 Generated with Claude Code

All six hessian-derived LensCalc methods now guard autoarray wrapping
with `if xp is np:` so they return a raw jax.Array on the JAX path.
This allows them to be called directly inside jax.jit without the
TypeError that occurred when an ArrayIrregular or Array2D was returned
as the JIT output.

Methods fixed:
- convergence_2d_via_hessian_from
- shear_yx_2d_via_hessian_from
- magnification_2d_via_hessian_from
- magnification_2d_from
- tangential_eigen_value_from (also adds jnp.sqrt for shear magnitudes)
- radial_eigen_value_from (same)

CLAUDE.md updated with decorator system overview and the if-xp-is-np
guard pattern for functions at the jax.jit boundary.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
@Jammy2211 Jammy2211 requested a review from Copilot March 6, 2026 16:09
@Jammy2211 Jammy2211 merged commit 5e06742 into main Mar 6, 2026
11 checks passed
@Jammy2211 Jammy2211 deleted the feature/jax_decorator_bypass branch March 6, 2026 16:15
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

Fixes JAX jax.jit boundary failures in LensCalc by avoiding returning non-pytree autoarray wrapper types on the JAX backend, and documents the project’s decorator/JAX patterns for future contributors.

Changes:

  • Guard six Hessian-derived LensCalc methods so NumPy returns autoarray wrappers while JAX returns raw arrays.
  • Add guidance to CLAUDE.md on autoarray decorators, xp backend usage, and the if xp is np: guard pattern.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 8 comments.

File Description
autogalaxy/operate/lens_calc.py Adds if xp is np: guards to avoid returning autoarray wrapper types from jax.jit-compiled functions.
CLAUDE.md Documents decorator stacking expectations and JAX/autoarray wrapper constraints at the jax.jit boundary.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines +553 to +555
if xp is np:
return aa.ArrayIrregular(values=1.0 / det_A)
return 1.0 / det_A
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

magnification_2d_via_hessian_from is annotated as returning aa.ArrayIrregular, but returns a raw array when xp is not np. Update the return type annotation (or add overloads) to match the conditional wrapper behavior.

Copilot uses AI. Check for mistakes.
Comment on lines +250 to +255
if xp is np:
return aa.Array2D(
values=1 - convergence - shear_yx.magnitudes, mask=grid.mask
)
magnitudes = xp.sqrt(shear_yx[:, 0] ** 2 + shear_yx[:, 1] ** 2)
return 1 - convergence - magnitudes
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change is specifically to avoid TypeError at the jax.jit boundary, but there’s no regression test exercising these methods under jax.jit. Since JAX is already available in the test suite (test_autogalaxy/conftest.py imports jax.numpy), consider adding a small unit test that jits a lambda calling each updated method with xp=jnp and asserts it runs/returns a JAX array.

Copilot uses AI. Check for mistakes.
Comment on lines +215 to +216
Prefer simple shell commands.
Avoid chaining with && or pipes. No newline at end of file
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The guidance says to avoid chaining with pipes, but the example immediately above uses a pipe (find ... | xargs ...). Either adjust the recommendation (e.g. avoid complex pipelines) or update the example so the instructions are internally consistent.

Suggested change
Prefer simple shell commands.
Avoid chaining with && or pipes.
Prefer simple, readable shell commands.
Avoid complex command chains with multiple && operators or many pipes; a single straightforward pipe (as above) is fine when it improves clarity.

Copilot uses AI. Check for mistakes.
Comment on lines +250 to +255
if xp is np:
return aa.Array2D(
values=1 - convergence - shear_yx.magnitudes, mask=grid.mask
)
magnitudes = xp.sqrt(shear_yx[:, 0] ** 2 + shear_yx[:, 1] ** 2)
return 1 - convergence - magnitudes
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These methods now return a raw JAX array when xp is not np, but the return annotation still claims aa.Array2D. This mismatch can mislead callers and type checkers; consider updating the annotation to a Union[...] (or overloads) that reflects both the NumPy-wrapped and JAX-raw return types.

Copilot uses AI. Check for mistakes.
Comment on lines +275 to +280
if xp is np:
return aa.Array2D(
values=1 - convergence + shear.magnitudes, mask=grid.mask
)
magnitudes = xp.sqrt(shear[:, 0] ** 2 + shear[:, 1] ** 2)
return 1 - convergence + magnitudes
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Return type annotation is aa.Array2D, but this function now returns a raw array on the JAX path (xp is not np). Update the annotation (e.g., Union[aa.Array2D, <array type>] / overloads) so the public API matches actual behavior.

Copilot uses AI. Check for mistakes.
Comment on lines +301 to +303
if xp is np:
return aa.Array2D(values=1 / det_A, mask=grid.mask)
return 1 / det_A
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

magnification_2d_from is annotated to return aa.Array2D, but returns a raw array when xp is not np. Please update the return type annotation (and/or add overloads) to reflect the conditional wrapper behavior.

Copilot uses AI. Check for mistakes.
Comment on lines +480 to +482
if xp is np:
return aa.ArrayIrregular(values=convergence)
return convergence
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convergence_2d_via_hessian_from is annotated as returning aa.ArrayIrregular, but returns a raw array on the JAX path. Adjust the return type annotation (e.g., Union[aa.ArrayIrregular, <array type>] / overloads) so callers can rely on the signature.

Copilot uses AI. Check for mistakes.
Comment on lines +523 to +525
if xp is np:
return ShearYX2DIrregular(values=shear_yx_2d, grid=grid)
return shear_yx_2d
Copy link

Copilot AI Mar 6, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shear_yx_2d_via_hessian_from is annotated to return ShearYX2DIrregular, but returns a raw (N, 2) array when xp is not np. Please update the return annotation (or add overloads) to reflect the JAX-compatible return type.

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants