Conversation
Jammy2211
left a comment
There was a problem hiding this comment.
- Delete commented out code.
- Make sure there are no module-level imports of JAX.
- Do anything from copilot you think sounds important.
Then good to merge!
There was a problem hiding this comment.
Pull request overview
This pull request makes the Gaussian mass profile JAX-compatible by replacing NumPy-specific operations with a generic xp parameter that can accept either NumPy or JAX arrays. The main changes involve refactoring mathematical operations to work with both array libraries and implementing a custom JAX-compatible Faddeeva function (wofz).
Changes:
- Replaced
npcalls withxpparameter throughout gaussian.py for JAX compatibility - Implemented custom JAX-compatible wofz (Faddeeva function) to replace scipy.special.wofz
- Commented out scipy-dependent deflections_2d_via_integral_from method and related tests
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| test_autogalaxy/profiles/mass/stellar/test_gaussian.py | Commented out tests for scipy-dependent integral methods that are not JAX-compatible |
| autogalaxy/profiles/mass/stellar/gaussian.py | Converted np operations to xp, implemented custom wofz function, refactored zeta_from for JAX compatibility, commented out scipy-dependent methods |
| autogalaxy/init.py | Commented out import from autoarray.dataset.interferometer.w_tilde |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| @@ -6,9 +6,9 @@ | |||
|
|
|||
| from autoconf.dictable import from_dict, from_json, output_to_json, to_dict | |||
| from autoarray.dataset import preprocess # noqa | |||
There was a problem hiding this comment.
The import from autoarray.dataset.interferometer.w_tilde has been commented out without explanation. If this import is no longer needed or causes issues, it should be removed entirely rather than commented out. If it's needed but temporarily causing issues, a TODO comment should explain why it's commented and what needs to be done to re-enable it.
| from autoarray.dataset import preprocess # noqa | |
| from autoarray.dataset import preprocess # noqa | |
| # TODO: This import is temporarily disabled because autoarray.dataset.interferometer.w_tilde | |
| # TODO: may not be available or compatible in all environments. Re-enable once the module | |
| # TODO: and load_curvature_preload_if_compatible are confirmed to be stable and required. |
This pull request makes significant changes to the implementation of the
Gaussianmass profile inautogalaxy, primarily focusing on improving compatibility with array libraries like JAX, simplifying the code, and deprecating the integral-based deflection calculation in favor of the analytic approach. The most important changes are grouped below:Core algorithm and compatibility improvements
zeta_frommethod inGaussianto a new, more efficient and JAX-compatible implementation, including a customwofz(Faddeeva function) for use with both NumPy and JAX backends. The new implementation replaces the previous approach that usedscipy.special.wofzdirectly, and ensures symmetry and compatibility with autodiff and JIT compilation.Gaussianto use thexp(array module) argument instead of hardcodednp, ensuring compatibility with both NumPy and JAX throughout the profile calculations. [1] [2] [3]Removal and deprecation of legacy/integral code
deflections_2d_via_integral_fromand its associated test cases, indicating a shift to exclusively using the analytic solution for deflection calculations. This simplifies the codebase and reduces maintenance overhead. [1] [2]deflections_yx_2d_frommethod to remove a now-unnecessary zero-intensity check and commented out legacy code, further simplifying the logic and relying solely on the analytic approach.Minor code and import cleanups
autogalaxy/__init__.pyto comment out unused imports related to interferometer curvature preloads.These changes collectively improve the maintainability, performance, and future-proofing of the
Gaussianmass profile implementation inautogalaxy.gaussian.py in mass profiles made to be used in jax.Commented out the
deflections_2d_via_integral_fromand the tests in which it is contained.Changed
defl = self.deflections_2d_via_analytic_from(grid=grid, xp=xp, **kwargs) return xp.where(self.intensity == 0.0, xp.zeros_like(defl), defl)back to
return self.deflections_2d_via_analytic_from(grid=grid, xp=xp, **kwargs)as the tests did not come through:
FAILED test_autogalaxy/profiles/test_light_and_mass_profiles.py::test__gaussian - AttributeError: 'numpy.ndarray' object has no attribute 'array'
FAILED test_autogalaxy/profiles/test_light_and_mass_profiles.py::test__gaussian_gradient - AttributeError: 'numpy.ndarray' object has no attribute 'array'