Skip to content

Commit 35604bb

Browse files
committed
grid_of_closest_from convert to JAX
1 parent 6451bc9 commit 35604bb

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

autoarray/structures/grids/irregular_2d.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -255,14 +255,12 @@ def grid_of_closest_from(self, grid_pair: "Grid2DIrregular") -> "Grid2DIrregular
255255
the `Grid2DIrregular` to the input grid.
256256
"""
257257

258-
grid_of_closest = jnp.zeros((grid_pair.shape[0], 2))
258+
jax_array = jnp.asarray(self.array)
259259

260-
for i in range(grid_pair.shape[0]):
261-
x_distances = np.square(np.subtract(grid_pair[i, 0], self[:, 0]))
262-
y_distances = np.square(np.subtract(grid_pair[i, 1], self[:, 1]))
260+
def closest_point(point):
261+
x_distances = jnp.square(point[0] - jax_array[:, 0])
262+
y_distances = jnp.square(point[1] - jax_array[:, 1])
263+
radial_distances = x_distances + y_distances
264+
return jax_array[jnp.argmin(radial_distances)]
263265

264-
radial_distances = np.add(x_distances, y_distances)
265-
266-
grid_of_closest[i, :] = self[np.argmin(radial_distances), :]
267-
268-
return Grid2DIrregular(values=grid_of_closest)
266+
return jax.vmap(closest_point)(grid_pair.array)

test_autoarray/structures/grids/test_irregular_2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def test__furthest_distances_to_other_coordinates():
9595
def test__grid_of_closest_from():
9696
grid = aa.Grid2DIrregular(values=[(0.0, 0.0), (0.0, 1.0)])
9797

98-
grid_of_closest = grid.grid_of_closest_from(grid_pair=aa.Grid2DIrregular([[0.0, 0.1]]))
98+
grid_of_closest = grid.grid_of_closest_from(grid_pair=aa.Grid2DIrregular(np.array([[0.0, 0.1]])))
9999

100100
assert (grid_of_closest == np.array([[0.0, 0.0]])).all()
101101

0 commit comments

Comments
 (0)