diff --git a/autogalaxy/__init__.py b/autogalaxy/__init__.py index ba666865d..947855403 100644 --- a/autogalaxy/__init__.py +++ b/autogalaxy/__init__.py @@ -22,8 +22,6 @@ from autoarray.inversion.pixelization.border_relocator import BorderRelocator from autoarray.mask.mask_1d import Mask1D # noqa from autoarray.mask.mask_2d import Mask2D # noqa -from autoarray.operators.convolver import Convolver # noqa -from autoarray.operators.convolver import Convolver # noqa from autoarray.operators.transformer import TransformerDFT # noqa from autoarray.operators.transformer import TransformerNUFFT # noqa from autoarray.layout.layout import Layout2D # noqa diff --git a/autogalaxy/analysis/analysis/analysis.py b/autogalaxy/analysis/analysis/analysis.py index ae583b03c..084af3104 100644 --- a/autogalaxy/analysis/analysis/analysis.py +++ b/autogalaxy/analysis/analysis/analysis.py @@ -192,18 +192,18 @@ def profile_log_likelihood_function( try: info_dict["image_pixels"] = self.dataset.grids.lp.shape_slim - info_dict[ - "sub_total_light_profiles" - ] = self.dataset.grids.lp.over_sampler.sub_total + info_dict["sub_total_light_profiles"] = ( + self.dataset.grids.lp.over_sampler.sub_total + ) except AttributeError: pass if fit.model_obj.has(cls=aa.Pixelization): info_dict["use_w_tilde"] = fit.inversion.settings.use_w_tilde try: - info_dict[ - "sub_total_pixelization" - ] = self.dataset.grids.pixelization.over_sampler.sub_total + info_dict["sub_total_pixelization"] = ( + self.dataset.grids.pixelization.over_sampler.sub_total + ) except AttributeError: pass info_dict["use_positive_only_solver"] = ( diff --git a/autogalaxy/analysis/chaining_util.py b/autogalaxy/analysis/chaining_util.py index 930b6a5f6..95306e45f 100644 --- a/autogalaxy/analysis/chaining_util.py +++ b/autogalaxy/analysis/chaining_util.py @@ -215,39 +215,29 @@ def extra_galaxies_from( for extra_galaxy_index in range(len(result.instance.extra_galaxies)): if hasattr(result.instance.extra_galaxies[extra_galaxy_index], "mass"): - extra_galaxies[ - extra_galaxy_index - ].mass.centre = result.instance.extra_galaxies[ - extra_galaxy_index - ].mass.centre - extra_galaxies[ - extra_galaxy_index - ].mass.einstein_radius = result.model.extra_galaxies[ - extra_galaxy_index - ].mass.einstein_radius + extra_galaxies[extra_galaxy_index].mass.centre = ( + result.instance.extra_galaxies[extra_galaxy_index].mass.centre + ) + extra_galaxies[extra_galaxy_index].mass.einstein_radius = ( + result.model.extra_galaxies[extra_galaxy_index].mass.einstein_radius + ) if free_centre: - extra_galaxies[ - extra_galaxy_index - ].mass.centre = result.model.extra_galaxies[ - extra_galaxy_index - ].mass.centre + extra_galaxies[extra_galaxy_index].mass.centre = ( + result.model.extra_galaxies[extra_galaxy_index].mass.centre + ) elif light_as_model: extra_galaxies = result.instance.extra_galaxies.as_model((LightProfile,)) for extra_galaxy_index in range(len(result.instance.extra_galaxies)): if extra_galaxies[extra_galaxy_index].bulge is not None: - extra_galaxies[ - extra_galaxy_index - ].bulge.centre = result.instance.extra_galaxies[ - extra_galaxy_index - ].bulge.centre + extra_galaxies[extra_galaxy_index].bulge.centre = ( + result.instance.extra_galaxies[extra_galaxy_index].bulge.centre + ) if free_centre: - extra_galaxies[ - extra_galaxy_index - ].bulge.centre = result.model.extra_galaxies[ - extra_galaxy_index - ].bulge.centre + extra_galaxies[extra_galaxy_index].bulge.centre = ( + result.model.extra_galaxies[extra_galaxy_index].bulge.centre + ) else: extra_galaxies = result.instance.extra_galaxies.as_model(()) diff --git a/autogalaxy/convert.py b/autogalaxy/convert.py index 4c1ae9299..80488824e 100644 --- a/autogalaxy/convert.py +++ b/autogalaxy/convert.py @@ -66,11 +66,7 @@ def axis_ratio_and_angle_from(ell_comps: Tuple[float, float]) -> Tuple[float, fl angle *= 180.0 / np.pi if use_jax: - angle = jax.lax.select( - angle < -45, - angle + 180, - angle - ) + angle = jax.lax.select(angle < -45, angle + 180, angle) else: if abs(angle) > 45 and angle < 0: angle += 180 diff --git a/autogalaxy/galaxy/to_inversion.py b/autogalaxy/galaxy/to_inversion.py index 67869b23c..00f47782c 100644 --- a/autogalaxy/galaxy/to_inversion.py +++ b/autogalaxy/galaxy/to_inversion.py @@ -81,20 +81,20 @@ def __init__( self.run_time_dict = run_time_dict @property - def convolver(self) -> Optional[aa.Convolver]: + def psf(self) -> Optional[aa.Kernel2D]: """ - Returns the convolver of the imaging dataset, if the inversion is performed on an imaging dataset. + Returns the PSF of the imaging dataset, if the inversion is performed on an imaging dataset. The `GalaxiesToInversion` class acts as an interface between the dataset and inversion module for - both imaging and interferometer datasets. Only imaging datasets have a convolver, thus this property - ensures that for an interferometer dataset code which references a convolver does not raise an error. + both imaging and interferometer datasets. Only imaging datasets have a PSF, thus this property + ensures that for an interferometer dataset code which references a PSF does not raise an error. Returns ------- - The convolver of the imaging dataset, if it is an imaging dataset. + The psf of the imaging dataset, if it is an imaging dataset. """ try: - return self.dataset.convolver + return self.dataset.psf except AttributeError: return None @@ -310,7 +310,7 @@ def cls_light_profile_func_list_galaxy_dict_from( lp_linear_func = LightProfileLinearObjFuncList( grid=self.dataset.grids.lp, blurring_grid=self.dataset.grids.blurring, - convolver=self.dataset.convolver, + psf=self.dataset.psf, light_profile_list=light_profile_list, regularization=light_profile.regularization, ) diff --git a/autogalaxy/imaging/fit_imaging.py b/autogalaxy/imaging/fit_imaging.py index da96317bb..8371acc76 100644 --- a/autogalaxy/imaging/fit_imaging.py +++ b/autogalaxy/imaging/fit_imaging.py @@ -103,7 +103,7 @@ def blurred_image(self) -> aa.Array2D: return self.galaxies.blurred_image_2d_from( grid=self.grids.lp, - convolver=self.dataset.convolver, + psf=self.dataset.psf, blurring_grid=self.grids.blurring, ) @@ -120,7 +120,7 @@ def galaxies_to_inversion(self) -> GalaxiesToInversion: data=self.profile_subtracted_image, noise_map=self.noise_map, grids=self.grids, - convolver=self.dataset.convolver, + psf=self.dataset.psf, w_tilde=self.w_tilde, ) @@ -179,7 +179,7 @@ def galaxy_model_image_dict(self) -> Dict[Galaxy, np.ndarray]: galaxy_blurred_image_2d_dict = self.galaxies.galaxy_blurred_image_2d_dict_from( grid=self.grids.lp, - convolver=self.dataset.convolver, + psf=self.dataset.psf, blurring_grid=self.grids.blurring, ) diff --git a/autogalaxy/operate/deflections.py b/autogalaxy/operate/deflections.py index ef60c45fa..f1ef3a89f 100644 --- a/autogalaxy/operate/deflections.py +++ b/autogalaxy/operate/deflections.py @@ -99,16 +99,10 @@ def one_step(r, _, theta, fun, fun_dr): @partial(jit, static_argnums=(4,)) def step_r(r, theta, fun, fun_dr, N=20): one_step_partial = jax.tree_util.Partial( - one_step, - theta=theta, - fun=fun, - fun_dr=fun_dr + one_step, theta=theta, fun=fun, fun_dr=fun_dr ) new_r = jax.lax.scan(one_step_partial, r, xs=np.arange(N))[0] - return np.stack([ - new_r * np.sin(theta), - new_r * np.cos(theta) - ]).T + return np.stack([new_r * np.sin(theta), new_r * np.cos(theta)]).T class OperateDeflections: @@ -129,7 +123,7 @@ class OperateDeflections: def deflections_yx_2d_from(self, grid: aa.type.Grid2DLike, **kwargs): raise NotImplementedError - + def deflections_yx_scalar(self, y, x, pixel_scales): if not use_jax: return @@ -140,13 +134,13 @@ def deflections_yx_scalar(self, y, x, pixel_scales): y=y.reshape(1), x=x.reshape(1), shape_native=(1, 1), - pixel_scales=pixel_scales + pixel_scales=pixel_scales, ) return self.deflections_yx_2d_from(g).squeeze() def __eq__(self, other): return self.__dict__ == other.__dict__ and self.__class__ is other.__class__ - + def __hash__(self): return hash(repr(self)) @@ -754,85 +748,73 @@ def jacobian_stack(self, y, x, pixel_scales): return else: return np.stack( - jax.jacfwd( - self.deflections_yx_scalar, - argnums=(0, 1) - )(y, x, pixel_scales) + jax.jacfwd(self.deflections_yx_scalar, argnums=(0, 1))( + y, x, pixel_scales + ) ) - + def jacobian_stack_vector(self, y, x, pixel_scales): if not use_jax: return else: return np.vectorize( - jax.tree_util.Partial( - self.jacobian_stack, - pixel_scales=pixel_scales - ), - signature='(),()->(i,i)' + jax.tree_util.Partial(self.jacobian_stack, pixel_scales=pixel_scales), + signature="(),()->(i,i)", )(y, x) - + def convergence_mag_shear_yx(self, y, x): J = self.jacobian_stack_vector(y, x, 0.05) K = 0.5 * (J[..., 0, 0] + J[..., 1, 1]) mag_shear = 0.5 * np.sqrt( - (J[..., 0, 1] + J[..., 1, 0])**2 + (J[..., 0, 0] - J[..., 1, 1])**2 + (J[..., 0, 1] + J[..., 1, 0]) ** 2 + (J[..., 0, 0] - J[..., 1, 1]) ** 2 ) return K, mag_shear - + @partial(jit, static_argnums=(0,)) def tangential_eigen_value_yx(self, y, x): K, mag_shear = self.convergence_mag_shear_yx(y, x) return 1 - K - mag_shear - + @partial(jit, static_argnums=(0, 3)) def tangential_eigen_value_rt(self, r, theta, centre=(0.0, 0.0)): y = r * np.sin(theta) + centre[0] x = r * np.cos(theta) + centre[1] return self.tangential_eigen_value_yx(y, x) - + @partial(jit, static_argnums=(0, 3)) def grad_r_tangential_eigen_value(self, r, theta, centre=(0.0, 0.0)): # ignore `self` with the `argnums` below - tangential_eigen_part = partial( - self.tangential_eigen_value_rt, - centre=centre - ) + tangential_eigen_part = partial(self.tangential_eigen_value_rt, centre=centre) return np.vectorize( - jax.jacfwd(tangential_eigen_part, argnums=(0,)), - signature='(),()->()' + jax.jacfwd(tangential_eigen_part, argnums=(0,)), signature="(),()->()" )(r, theta)[0] @partial(jit, static_argnums=(0,)) def radial_eigen_value_yx(self, y, x): K, mag_shear = self.convergence_mag_shear_yx(y, x) return 1 - K + mag_shear - + @partial(jit, static_argnums=(0, 3)) def radial_eigen_value_rt(self, r, theta, centre=(0.0, 0.0)): y = r * np.sin(theta) + centre[0] x = r * np.cos(theta) + centre[1] return self.radial_eigen_value_yx(y, x) - + @partial(jit, static_argnums=(0, 3)) def grad_r_radial_eigen_value(self, r, theta, centre=(0.0, 0.0)): # ignore `self` with the `argnums` below - radial_eigen_part = partial( - self.radial_eigen_value_rt, - centre=centre - ) + radial_eigen_part = partial(self.radial_eigen_value_rt, centre=centre) return np.vectorize( - jax.jacfwd(radial_eigen_part, argnums=(0,)), - signature='(),()->()' + jax.jacfwd(radial_eigen_part, argnums=(0,)), signature="(),()->()" )(r, theta)[0] - + def tangential_critical_curve_jax( self, init_r=0.1, init_centre=(0.0, 0.0), n_points=300, n_steps=20, - threshold=1e-5 + threshold=1e-5, ): """ Returns all tangential critical curves of the lensing system, which are computed as follows: @@ -865,8 +847,10 @@ def tangential_critical_curve_jax( r, theta, jax.tree_util.Partial(self.tangential_eigen_value_rt, centre=init_centre), - jax.tree_util.Partial(self.grad_r_tangential_eigen_value, centre=init_centre), - n_steps + jax.tree_util.Partial( + self.grad_r_tangential_eigen_value, centre=init_centre + ), + n_steps, ) new_yx = new_yx + np.array(init_centre) # filter out nan values @@ -876,14 +860,14 @@ def tangential_critical_curve_jax( value = np.abs(self.tangential_eigen_value_yx(new_yx[:, 0], new_yx[:, 1])) gdx = value <= threshold return aa.structures.grids.irregular_2d.Grid2DIrregular(values=new_yx[gdx]) - + def radial_critical_curve_jax( self, init_r=0.01, init_centre=(0.0, 0.0), n_points=300, n_steps=20, - threshold=1e-5 + threshold=1e-5, ): """ Returns all radial critical curves of the lensing system, which are computed as follows: @@ -917,7 +901,7 @@ def radial_critical_curve_jax( theta, jax.tree_util.Partial(self.radial_eigen_value_rt, centre=init_centre), jax.tree_util.Partial(self.grad_r_radial_eigen_value, centre=init_centre), - n_steps + n_steps, ) new_yx = new_yx + np.array(init_centre) # filter out nan values @@ -953,45 +937,51 @@ def jacobian_from(self, grid): a11 = aa.Array2D( values=1.0 - - np.gradient(deflections.native[:, :, 1], grid.native[0, :, 1], axis=1), + - np.gradient( + deflections.native[:, :, 1], grid.native[0, :, 1], axis=1 + ), mask=grid.mask, ) a12 = aa.Array2D( values=-1.0 - * np.gradient(deflections.native[:, :, 1], grid.native[:, 0, 0], axis=0), + * np.gradient( + deflections.native[:, :, 1], grid.native[:, 0, 0], axis=0 + ), mask=grid.mask, ) a21 = aa.Array2D( values=-1.0 - * np.gradient(deflections.native[:, :, 0], grid.native[0, :, 1], axis=1), + * np.gradient( + deflections.native[:, :, 0], grid.native[0, :, 1], axis=1 + ), mask=grid.mask, ) a22 = aa.Array2D( values=1 - - np.gradient(deflections.native[:, :, 0], grid.native[:, 0, 0], axis=0), + - np.gradient( + deflections.native[:, :, 0], grid.native[:, 0, 0], axis=0 + ), mask=grid.mask, ) return [[a11, a12], [a21, a22]] else: A = self.jacobian_stack_vector( - grid.array[:, 0], - grid.array[:, 1], - grid.pixel_scales + grid.array[:, 0], grid.array[:, 1], grid.pixel_scales ) a = np.eye(2).reshape(1, 2, 2) - A return [ [ aa.Array2D(values=a[..., 1, 1], mask=grid.mask), - aa.Array2D(values=a[..., 1, 0], mask=grid.mask) + aa.Array2D(values=a[..., 1, 0], mask=grid.mask), ], [ aa.Array2D(values=a[..., 0, 1], mask=grid.mask), - aa.Array2D(values=a[..., 0, 0], mask=grid.mask) - ] + aa.Array2D(values=a[..., 0, 0], mask=grid.mask), + ], ] # transpose the result diff --git a/autogalaxy/operate/image.py b/autogalaxy/operate/image.py index bc639e421..a1fa126da 100644 --- a/autogalaxy/operate/image.py +++ b/autogalaxy/operate/image.py @@ -2,6 +2,8 @@ import numpy as np from typing import TYPE_CHECKING, Dict, List, Optional +from autoarray import Array2D + if TYPE_CHECKING: from autogalaxy.galaxy.galaxy import Galaxy @@ -34,31 +36,20 @@ def _blurred_image_2d_from( self, image_2d: aa.Array2D, blurring_image_2d: aa.Array2D, - psf: Optional[aa.Kernel2D], - convolver: aa.Convolver, + psf: aa.Kernel2D, ) -> aa.Array2D: - if psf is not None: - return psf.convolved_array_with_mask_from( - array=image_2d.native + blurring_image_2d.native, - mask=image_2d.mask, - ) - - elif convolver is not None: - return convolver.convolve_image( - image=image_2d, blurring_image=blurring_image_2d - ) - else: - raise exc.OperateException( - "A PSF or Convolver was not passed to the `blurred_image_2d_list_from()` function." - ) + values = psf.convolve_image( + image=image_2d, + blurring_image=blurring_image_2d, + ) + return Array2D(values=values, mask=image_2d.mask) def blurred_image_2d_from( self, grid: aa.Grid2D, blurring_grid: aa.Grid2D, - psf: Optional[aa.Kernel2D] = None, - convolver: aa.Convolver = None, + psf: aa.Kernel2D = None, ) -> aa.Array2D: """ Evaluate the light object's 2D image from a input 2D grid of coordinates and convolve it with a PSF. @@ -92,7 +83,6 @@ def blurred_image_2d_from( image_2d=image_2d_not_operated, blurring_image_2d=blurring_image_2d_not_operated, psf=psf, - convolver=convolver, ) if self.has(cls=LightProfileOperated): @@ -224,8 +214,7 @@ def blurred_image_2d_list_from( self, grid: aa.Grid2D, blurring_grid: aa.Grid2D, - psf: Optional[aa.Kernel2D] = None, - convolver: aa.Convolver = None, + psf: aa.Kernel2D = None, ) -> List[aa.Array2D]: """ Evaluate the light object's list of 2D images from a input 2D grid of coordinates and convolve each image with @@ -267,7 +256,6 @@ def blurred_image_2d_list_from( image_2d=image_2d_not_operated, blurring_image_2d=blurring_image_2d_not_operated, psf=psf, - convolver=convolver, ) image_2d_operated = image_2d_operated_list[i] @@ -375,7 +363,7 @@ def galaxy_image_2d_dict_from( raise NotImplementedError def galaxy_blurred_image_2d_dict_from( - self, grid, convolver, blurring_grid + self, grid, psf, blurring_grid ) -> Dict[Galaxy, aa.Array2D]: """ Evaluate the light object's dictionary mapping galaixes to their corresponding 2D images and convolve each @@ -418,7 +406,7 @@ def galaxy_blurred_image_2d_dict_from( galaxy_key ] - blurred_image_2d = convolver.convolve_image( + blurred_image_2d = psf.convolve_image( image=image_2d_not_operated, blurring_image=blurring_image_2d_not_operated, ) diff --git a/autogalaxy/profiles/geometry_profiles.py b/autogalaxy/profiles/geometry_profiles.py index aa9399410..7175851d7 100644 --- a/autogalaxy/profiles/geometry_profiles.py +++ b/autogalaxy/profiles/geometry_profiles.py @@ -4,9 +4,11 @@ if os.environ.get("USE_JAX", "0") == "1": import jax.numpy as np + use_jax = True else: import numpy as np + use_jax = False import autoarray as aa diff --git a/autogalaxy/profiles/light/linear/abstract.py b/autogalaxy/profiles/light/linear/abstract.py index a44194e3b..ff96ed0ca 100644 --- a/autogalaxy/profiles/light/linear/abstract.py +++ b/autogalaxy/profiles/light/linear/abstract.py @@ -143,7 +143,7 @@ def __init__( self, grid: aa.type.Grid1D2DLike, blurring_grid: aa.type.Grid1D2DLike, - convolver: Optional[aa.Convolver], + psf: Optional[aa.Kernel2D], light_profile_list: List[LightProfileLinear], regularization=Optional[aa.reg.Regularization], run_time_dict: Optional[Dict] = None, @@ -184,8 +184,8 @@ def __init__( blurring_grid The blurring grid is all points whose light is outside the data's mask but close enough to the mask that it may be blurred into the mask. This is also used when evaluating the image of each light profile. - convolver - The convolver used to blur the light profile images of each light profile, the output of which + psf + The psf used to blur the light profile images of each light profile, the output of which makes up the columns of the `operated_mapping matrix`. light_profile_list A list of the linear light profiles that are used to fit the data via linear algebra. @@ -210,7 +210,7 @@ def __init__( ) self.blurring_grid = blurring_grid - self.convolver = convolver + self.psf = psf self.light_profile_list = light_profile_list @property @@ -271,7 +271,7 @@ def mapping_matrix(self) -> np.ndarray: @cached_property def operated_mapping_matrix_override(self) -> Optional[np.ndarray]: """ - The inversion object takes the `mapping_matrix` of each linear object and combines it with the `Convolver` + The inversion object takes the `mapping_matrix` of each linear object and combines it with the PSF operator to perform a 2D convolution and compute the `operated_mapping_matrix`. If this property is overwritten this operation is not performed, with the `operated_mapping_matrix` output this @@ -298,7 +298,7 @@ def operated_mapping_matrix_override(self) -> Optional[np.ndarray]: blurring_image_2d = light_profile.image_2d_from(grid=self.blurring_grid) - blurred_image_2d = self.convolver.convolve_image( + blurred_image_2d = self.psf.convolve_image( image=image_2d, blurring_image=blurring_image_2d ) diff --git a/autogalaxy/profiles/light/snr/abstract.py b/autogalaxy/profiles/light/snr/abstract.py index 49b34a766..71350f0c0 100644 --- a/autogalaxy/profiles/light/snr/abstract.py +++ b/autogalaxy/profiles/light/snr/abstract.py @@ -84,7 +84,9 @@ def set_intensity_from( image_2d = self.image_2d_from(grid=grid) if psf is not None: - image_2d = psf.convolved_array_from(array=image_2d) + image_2d = psf.convolve_image_no_blurring( + image=image_2d, mask=image_2d.mask + ) brightest_value = np.max(image_2d) diff --git a/autogalaxy/profiles/light/standard/gaussian.py b/autogalaxy/profiles/light/standard/gaussian.py index a297c6bf5..ec16af0a8 100644 --- a/autogalaxy/profiles/light/standard/gaussian.py +++ b/autogalaxy/profiles/light/standard/gaussian.py @@ -66,7 +66,9 @@ def image_2d_via_radii_from(self, grid_radii: np.ndarray) -> np.ndarray: np.exp( -0.5 * np.square( - np.divide(grid_radii.array, self.sigma / np.sqrt(self.axis_ratio)) + np.divide( + grid_radii.array, self.sigma / np.sqrt(self.axis_ratio) + ) ) ), ) diff --git a/autogalaxy/profiles/mass/abstract/jax_utils.py b/autogalaxy/profiles/mass/abstract/jax_utils.py index 9a20233aa..e241488f3 100644 --- a/autogalaxy/profiles/mass/abstract/jax_utils.py +++ b/autogalaxy/profiles/mass/abstract/jax_utils.py @@ -6,7 +6,7 @@ r1_s1 = [2.5, 2, 1.5, 1, 0.5] -def reg1(z, _ , i_sqrt_pi): +def reg1(z, _, i_sqrt_pi): v = z for coef in r1_s1: v = z - coef / v @@ -18,7 +18,7 @@ def reg1(z, _ , i_sqrt_pi): def reg2(z, sqrt_pi, _): - mz2 = -z**2 + mz2 = -(z**2) f1 = sqrt_pi f2 = 1.0 for s in r2_s1: @@ -64,17 +64,17 @@ def w_f_approx(z): # use a single partial fraction approx for all large abs(z)**2 # to have better approx of the auto-derivatives r1 = (abs_z2 >= 62.0) | ((abs_z2 >= 30.0) & (abs_z2 < 62.0) & (z_imag2 >= 1e-13)) - # region bounds for 5 taken directly from Zaghloul (2017) + # region bounds for 5 taken directly from Zaghloul (2017) # https://dl.acm.org/doi/pdf/10.1145/3119904 r2_1 = (abs_z2 >= 30.0) & (abs_z2 < 62.0) & (z_imag2 < 1e-13) r2_2 = (abs_z2 >= 2.5) & (abs_z2 < 30.0) & (z_imag2 < 0.072) r2 = r2_1 | r2_2 r3 = jnp.logical_not(r1) & jnp.logical_not(r2) - # exploit symmetry to avoid overflow in some regions + # exploit symmetry to avoid overflow in some regions r_flip = z.imag < 0 z_adjust = jnp.where(r_flip, -z, z) - two_exp_zz = 2 * jnp.exp(-z_adjust**2) + two_exp_zz = 2 * jnp.exp(-(z_adjust**2)) args = (z_adjust, sqrt_pi, i_sqrt_pi) wz = jnp.empty_like(z) @@ -82,7 +82,7 @@ def w_f_approx(z): wz = jnp.where(r2, reg2(*args), wz) wz = jnp.where(r3, reg3(*args), wz) - # exploit symmetry to avoid overflow in some regions + # exploit symmetry to avoid overflow in some regions wz = jnp.where(r_flip, two_exp_zz - wz, wz) return wz @@ -91,8 +91,8 @@ def w_f_approx(z): @w_f_approx.defjvp def w_f_approx_jvp(primals, tangents): # define a custom jvp to avoid the issue using `jnp.where` with `jax.grad` - z, = primals - z_dot, = tangents + (z,) = primals + (z_dot,) = tangents primal_out = w_f_approx(z) i_sqrt_pi = 1j / jnp.sqrt(jnp.pi) tangent_out = z_dot * 2 * (i_sqrt_pi - z * primal_out) diff --git a/autogalaxy/profiles/mass/abstract/mge_jax.py b/autogalaxy/profiles/mass/abstract/mge_jax.py index 89d845262..42a568a4c 100644 --- a/autogalaxy/profiles/mass/abstract/mge_jax.py +++ b/autogalaxy/profiles/mass/abstract/mge_jax.py @@ -65,14 +65,11 @@ def eta(p): i = np.arange(1, p, 1) kesi_last = 1 / 2**p k = kesi_last + np.cumsum(np.cumprod((p + 1 - i) / i) * kesi_last) - - kesi_list = np.hstack([ - np.array([0.5]), - np.ones(p), - k[::-1], - np.array([kesi_last]) - ]) - coef = (-1)**np.arange(0, 2 * p + 1, 1) + + kesi_list = np.hstack( + [np.array([0.5]), np.ones(p), k[::-1], np.array([kesi_last])] + ) + coef = (-1) ** np.arange(0, 2 * p + 1, 1) eta_const = 2.0 * np.sqrt(2.0 * np.pi) * 10 ** (p / 3.0) eta_list = coef * kesi_list return eta_const, eta_list @@ -111,10 +108,7 @@ def _decompose_convergence_via_mge( amplitude_list = np.zeros(func_gaussians) f_sigma = eta_constant * np.sum( - eta_n * np.real(func( - sigma_list.reshape(-1, 1) * kesis - )), - axis=1 + eta_n * np.real(func(sigma_list.reshape(-1, 1) * kesis)), axis=1 ) amplitude_list = f_sigma * d_log_sigma / np.sqrt(2.0 * np.pi) amplitude_list = amplitude_list.at[0].multiply(0.5) diff --git a/autogalaxy/profiles/mass/dark/gnfw_virial_mass_gnfw_conc.py b/autogalaxy/profiles/mass/dark/gnfw_virial_mass_gnfw_conc.py index ad466df1a..5be2e7840 100644 --- a/autogalaxy/profiles/mass/dark/gnfw_virial_mass_gnfw_conc.py +++ b/autogalaxy/profiles/mass/dark/gnfw_virial_mass_gnfw_conc.py @@ -49,9 +49,7 @@ def kappa_s_and_scale_radius( ############################## def integrand(r): - return (r**2 / r**inner_slope) * (1 + r / scale_radius_kpc) ** ( - inner_slope - 3 - ) + return (r**2 / r**inner_slope) * (1 + r / scale_radius_kpc) ** (inner_slope - 3) de_c = ( (overdens / 3.0) diff --git a/autogalaxy/profiles/mass/stellar/gaussian.py b/autogalaxy/profiles/mass/stellar/gaussian.py index d0cf66499..bead9e95f 100644 --- a/autogalaxy/profiles/mass/stellar/gaussian.py +++ b/autogalaxy/profiles/mass/stellar/gaussian.py @@ -1,6 +1,7 @@ import copy import numpy as np from autofit.jax_wrapper import use_jax + if use_jax: import jax from scipy.special import wofz @@ -192,11 +193,7 @@ def image_2d_via_radii_from(self, grid_radii: np.ndarray): def axis_ratio(self): axis_ratio = super().axis_ratio if use_jax: - return jax.lax.select( - axis_ratio < 0.9999, - axis_ratio, - 0.9999 - ) + return jax.lax.select(axis_ratio < 0.9999, axis_ratio, 0.9999) else: return axis_ratio if axis_ratio < 0.9999 else 0.9999 diff --git a/test_autogalaxy/conftest.py b/test_autogalaxy/conftest.py index 5fb97e1c1..08653cb85 100644 --- a/test_autogalaxy/conftest.py +++ b/test_autogalaxy/conftest.py @@ -135,11 +135,6 @@ def make_mask_2d_7x7(): return fixtures.make_mask_2d_7x7() -@pytest.fixture(name="convolver_7x7") -def make_convolver_7x7(): - return fixtures.make_convolver_7x7() - - @pytest.fixture(name="mask_2d_7x7_1_pix") def make_mask_2d_7x7_1_pix(): return fixtures.make_mask_2d_7x7_1_pix() diff --git a/test_autogalaxy/imaging/test_fit_imaging.py b/test_autogalaxy/imaging/test_fit_imaging.py index 3b3103fc2..f66d300b0 100644 --- a/test_autogalaxy/imaging/test_fit_imaging.py +++ b/test_autogalaxy/imaging/test_fit_imaging.py @@ -239,13 +239,13 @@ def test__galaxy_model_image_dict(masked_imaging_7x7): g0_blurred_image_2d = g0.blurred_image_2d_from( grid=masked_imaging_7x7.grids.lp, blurring_grid=masked_imaging_7x7.grids.blurring, - convolver=masked_imaging_7x7.convolver, + psf=masked_imaging_7x7.psf, ) g1_blurred_image_2d = g1.blurred_image_2d_from( grid=masked_imaging_7x7.grids.lp, blurring_grid=masked_imaging_7x7.grids.blurring, - convolver=masked_imaging_7x7.convolver, + psf=masked_imaging_7x7.psf, ) assert fit.galaxy_model_image_dict[g0] == pytest.approx(g0_blurred_image_2d, 1.0e-4) diff --git a/test_autogalaxy/imaging/test_simulate_and_fit_imaging.py b/test_autogalaxy/imaging/test_simulate_and_fit_imaging.py index e60c00cce..bd4af6ff5 100644 --- a/test_autogalaxy/imaging/test_simulate_and_fit_imaging.py +++ b/test_autogalaxy/imaging/test_simulate_and_fit_imaging.py @@ -202,7 +202,7 @@ def test__simulate_imaging_data_and_fit__linear_light_profiles_agree_with_standa galaxy_image = galaxy.blurred_image_2d_from( grid=masked_dataset.grids.lp, - convolver=masked_dataset.convolver, + psf=masked_dataset.psf, blurring_grid=masked_dataset.grids.blurring, ) diff --git a/test_autogalaxy/operate/test_image.py b/test_autogalaxy/operate/test_image.py index d2fea8ca9..cd0917de9 100644 --- a/test_autogalaxy/operate/test_image.py +++ b/test_autogalaxy/operate/test_image.py @@ -9,14 +9,16 @@ def test__blurred_image_2d_from( - grid_2d_7x7, blurring_grid_2d_7x7, psf_3x3, convolver_7x7 + grid_2d_7x7, + blurring_grid_2d_7x7, + psf_3x3, ): lp = ag.lp.Sersic(intensity=1.0) image_2d = lp.image_2d_from(grid=grid_2d_7x7) blurring_image_2d = lp.image_2d_from(grid=blurring_grid_2d_7x7) - blurred_image_2d_manual = convolver_7x7.convolve_image( + blurred_image_2d_manual = psf_3x3.convolve_image( image=image_2d, blurring_image=blurring_image_2d ) @@ -29,9 +31,7 @@ def test__blurred_image_2d_from( ) lp_blurred_image_2d = lp.blurred_image_2d_from( - grid=grid_2d_7x7, - blurring_grid=blurring_grid_2d_7x7, - convolver=convolver_7x7, + grid=grid_2d_7x7, blurring_grid=blurring_grid_2d_7x7, psf=psf_3x3 ) assert blurred_image_2d_manual.native == pytest.approx( @@ -55,7 +55,7 @@ def test__blurred_image_2d_from( grid=grid_2d_7x7, psf=psf_3x3, blurring_grid=blurring_grid_2d_7x7 ) - blurred_image_2d_manual_not_operated = convolver_7x7.convolve_image( + blurred_image_2d_manual_not_operated = psf_3x3.convolve_image( image=image_2d_not_operated, blurring_image=blurring_image_2d_not_operated, ) @@ -167,7 +167,9 @@ def test__visibilities_from_grid_and_transformer(grid_2d_7x7, transformer_7x7_7) def test__blurred_image_2d_list_from( - grid_2d_7x7, blurring_grid_2d_7x7, psf_3x3, convolver_7x7 + grid_2d_7x7, + blurring_grid_2d_7x7, + psf_3x3, ): lp_0 = ag.lp.Gaussian(intensity=1.0) lp_1 = ag.lp.Gaussian(intensity=2.0) @@ -194,9 +196,7 @@ def test__blurred_image_2d_list_from( ) blurred_image_2d_list = gal.blurred_image_2d_list_from( - grid=grid_2d_7x7, - blurring_grid=blurring_grid_2d_7x7, - convolver=convolver_7x7, + grid=grid_2d_7x7, blurring_grid=blurring_grid_2d_7x7, psf=psf_3x3 ) assert blurred_image_2d_list[0].native == pytest.approx( @@ -224,9 +224,7 @@ def test__blurred_image_2d_list_from( ) blurred_image_2d_list = gal.blurred_image_2d_list_from( - grid=grid_2d_7x7, - blurring_grid=blurring_grid_2d_7x7, - convolver=convolver_7x7, + grid=grid_2d_7x7, blurring_grid=blurring_grid_2d_7x7, psf=psf_3x3 ) assert blurred_image_2d_list[0].native == pytest.approx( @@ -296,9 +294,7 @@ def test__visibilities_list_from(grid_2d_7x7, transformer_7x7_7): assert (lp_1_visibilities == visibilities_list[1]).all() -def test__galaxy_blurred_image_2d_dict_from( - grid_2d_7x7, blurring_grid_2d_7x7, convolver_7x7 -): +def test__galaxy_blurred_image_2d_dict_from(grid_2d_7x7, blurring_grid_2d_7x7, psf_3x3): lp_0 = ag.lp.Sersic(intensity=1.0) g0 = ag.Galaxy(redshift=0.5, light_profile=lp_0) @@ -312,13 +308,13 @@ def test__galaxy_blurred_image_2d_dict_from( blurred_image_2d_list = galaxies.blurred_image_2d_list_from( grid=grid_2d_7x7, - convolver=convolver_7x7, + psf=psf_3x3, blurring_grid=blurring_grid_2d_7x7, ) blurred_image_dict = galaxies.galaxy_blurred_image_2d_dict_from( grid=grid_2d_7x7, - convolver=convolver_7x7, + psf=psf_3x3, blurring_grid=blurring_grid_2d_7x7, ) @@ -329,7 +325,7 @@ def test__galaxy_blurred_image_2d_dict_from( image_2d = lp_0.image_2d_from(grid=grid_2d_7x7) blurring_image_2d = lp_0.image_2d_from(grid=blurring_grid_2d_7x7) - image_2d_convolved = convolver_7x7.convolve_image( + image_2d_convolved = psf_3x3.convolve_image( image=image_2d, blurring_image=blurring_image_2d ) diff --git a/test_autogalaxy/profiles/light/linear/test_abstract.py b/test_autogalaxy/profiles/light/linear/test_abstract.py index f8ca3ad0e..53ae0c6ee 100644 --- a/test_autogalaxy/profiles/light/linear/test_abstract.py +++ b/test_autogalaxy/profiles/light/linear/test_abstract.py @@ -9,14 +9,14 @@ ) -def test__mapping_matrix_from(grid_2d_7x7, blurring_grid_2d_7x7, convolver_7x7): +def test__mapping_matrix_from(grid_2d_7x7, blurring_grid_2d_7x7, psf_3x3): lp_0 = ag.lp_linear.Sersic(effective_radius=1.0) lp_1 = ag.lp_linear.Sersic(effective_radius=2.0) lp_linear_obj_func_list = LightProfileLinearObjFuncList( grid=grid_2d_7x7, blurring_grid=blurring_grid_2d_7x7, - convolver=convolver_7x7, + psf=psf_3x3, light_profile_list=[lp_0, lp_1], ) @@ -33,15 +33,11 @@ def test__mapping_matrix_from(grid_2d_7x7, blurring_grid_2d_7x7, convolver_7x7): ) lp_0_blurred_image = lp_0.blurred_image_2d_from( - grid=grid_2d_7x7, - blurring_grid=blurring_grid_2d_7x7, - convolver=convolver_7x7, + grid=grid_2d_7x7, blurring_grid=blurring_grid_2d_7x7, psf=psf_3x3 ) lp_1_blurred_image = lp_1.blurred_image_2d_from( - grid=grid_2d_7x7, - blurring_grid=blurring_grid_2d_7x7, - convolver=convolver_7x7, + grid=grid_2d_7x7, blurring_grid=blurring_grid_2d_7x7, psf=psf_3x3 ) assert lp_linear_obj_func_list.operated_mapping_matrix_override[