Conversation
There was a problem hiding this comment.
Pull request overview
This PR improves NumPy/JAX backend flexibility by consistently propagating the xp (array module) argument through galaxy image / blurred-image / visibilities dictionary helpers and their primary call sites in imaging and interferometer fits.
Changes:
- Add/propagate
xpthroughgalaxy_image_2d_dict_from,galaxy_blurred_image_2d_dict_from, andgalaxy_visibilities_dict_frompathways. - Update
FitImaging/FitInterferometerto passself._xpinto galaxy dict computations. - Minor signature/formatting cleanup (e.g., inversion fit init signature formatting).
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 3 comments.
Show a summary per file
| File | Description |
|---|---|
| autogalaxy/operate/image.py | Adds xp forwarding for galaxy image/blurred-image/visibilities dictionary helpers (but currently misses forwarding xp into the transformer call). |
| autogalaxy/interferometer/fit_interferometer.py | Passes self._xp into galaxy image/visibilities dict computations for backend consistency. |
| autogalaxy/imaging/fit_imaging.py | Passes self._xp into galaxy image/blurred-image dict computations for backend consistency. |
| autogalaxy/galaxy/galaxies.py | Reformats galaxy_image_2d_dict_from signature; return type annotation remains incorrect and should be fixed while touched. |
| autogalaxy/abstract_fit.py | Constructor signature formatting cleanup; minor line wrapping. |
Comments suppressed due to low confidence (2)
autogalaxy/imaging/fit_imaging.py:20
- This class does not call AbstractFitInversion.init during initialization. (FitImaging.init may be missing a call to a base class init)
class FitImaging(aa.FitImaging, AbstractFitInversion):
autogalaxy/interferometer/fit_interferometer.py:15
- This class does not call AbstractFitInversion.init during initialization. (FitInterferometer.init may be missing a call to a base class init)
class FitInterferometer(aa.FitInterferometer, AbstractFitInversion):
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| """ | ||
|
|
||
| galaxy_image_2d_dict = self.galaxy_image_2d_dict_from(grid=grid) | ||
| galaxy_image_2d_dict = self.galaxy_image_2d_dict_from(grid=grid, xp=xp) |
There was a problem hiding this comment.
galaxy_visibilities_dict_from now correctly forwards xp into galaxy_image_2d_dict_from, but the subsequent call to transformer.visibilities_from(...) still omits xp. This will break JAX / non-numpy backends (and is inconsistent with visibilities_list_from / visibilities_from above). Pass xp=xp into transformer.visibilities_from here as well.
| grid: aa.type.Grid2DLike, | ||
| xp=np, | ||
| operated_only: Optional[bool] = None, | ||
| ) -> {Galaxy: np.ndarray}: |
There was a problem hiding this comment.
The return type annotation -> {Galaxy: np.ndarray} is a dict literal, not a typing annotation, and it doesn’t match the actual returned values (which are aa.Array2D instances from image_2d_list_from). Replace it with an appropriate typing type (e.g. Dict[Galaxy, aa.Array2D]) to avoid confusing/incorrect type information.
| ) -> {Galaxy: np.ndarray}: | |
| ) -> Dict[Galaxy, aa.Array2D]: |
| grid=grid, operated_only=False, xp=xp | ||
| ) | ||
|
|
||
| galaxy_blurring_image_2d_not_operated_dict = self.galaxy_image_2d_dict_from( | ||
| grid=blurring_grid, operated_only=False | ||
| grid=blurring_grid, operated_only=False, xp=xp |
There was a problem hiding this comment.
The new xp propagation in the galaxy-dict helpers isn’t covered by tests for a non-numpy backend. Consider extending the existing test_autogalaxy/operate/test_image.py coverage to call these methods with xp=jax.numpy (or self._xp) and assert the outputs match the corresponding *_list_from(..., xp=...) results / are computed via the same backend.
This pull request mainly focuses on improving consistency and flexibility in handling the
xp(array module, e.g., numpy or jax.numpy) argument throughout the galaxy modeling codebase. The changes ensure that the correct array module is consistently passed through all relevant methods, which is important for supporting both numpy and jax backends.Key changes include:
Consistent xp Argument Propagation:
galaxy_image_2d_dict_from,galaxy_blurred_image_2d_dict_from, andgalaxy_visibilities_dict_frommethods inautogalaxy/operate/image.pyto accept and pass thexpargument, ensuring the correct array module is used throughout the computation. [1] [2] [3] [4]autogalaxy/imaging/fit_imaging.pyandautogalaxy/interferometer/fit_interferometer.pyto explicitly provide thexpargument, usingself._xpwhere appropriate. [1] [2] [3] [4]Constructor and Method Signature Cleanups:
__init__signatures for inversion fit classes to match the new argument style, improving readability and consistency. [1] [2] [3]galaxy_image_2d_dict_frominautogalaxy/galaxy/galaxies.pyto include thexpargument, aligning with the new interface.Minor Formatting Improvements:
These changes collectively improve backend flexibility and maintainability, especially for environments that may use different array computation libraries.