diff --git a/autogalaxy/operate/deflections.py b/autogalaxy/operate/deflections.py index 37883a8b6..27286bcae 100644 --- a/autogalaxy/operate/deflections.py +++ b/autogalaxy/operate/deflections.py @@ -3,10 +3,8 @@ import jax.numpy as jnp from functools import wraps, partial import logging -import numpy as np from typing import List, Tuple, Union -from autoconf import conf import autoarray as aa @@ -279,21 +277,21 @@ def hessian_from(self, grid, buffer: float = 0.01, deflections_func=None) -> Tup if deflections_func is None: deflections_func = self.deflections_yx_2d_from - grid_shift_y_up = aa.Grid2DIrregular(values=np.zeros(grid.shape)) - grid_shift_y_up[:, 0] = grid[:, 0] + buffer - grid_shift_y_up[:, 1] = grid[:, 1] + grid_shift_y_up = aa.Grid2DIrregular( + values=jnp.stack([grid[:, 0] + buffer, grid[:, 1]], axis=1) + ) - grid_shift_y_down = aa.Grid2DIrregular(values=np.zeros(grid.shape)) - grid_shift_y_down[:, 0] = grid[:, 0] - buffer - grid_shift_y_down[:, 1] = grid[:, 1] + grid_shift_y_down = aa.Grid2DIrregular( + values=jnp.stack([grid[:, 0] - buffer, grid[:, 1]], axis=1) + ) - grid_shift_x_left = aa.Grid2DIrregular(values=np.zeros(grid.shape)) - grid_shift_x_left[:, 0] = grid[:, 0] - grid_shift_x_left[:, 1] = grid[:, 1] - buffer + grid_shift_x_left = aa.Grid2DIrregular( + values=jnp.stack([grid[:, 0], grid[:, 1] - buffer], axis=1) + ) - grid_shift_x_right = aa.Grid2DIrregular(values=np.zeros(grid.shape)) - grid_shift_x_right[:, 0] = grid[:, 0] - grid_shift_x_right[:, 1] = grid[:, 1] + buffer + grid_shift_x_right = aa.Grid2DIrregular( + values=jnp.stack([grid[:, 0], grid[:, 1] + buffer], axis=1) + ) deflections_up = deflections_func(grid=grid_shift_y_up) deflections_down = deflections_func(grid=grid_shift_y_down)