Skip to content

Commit bb5bc27

Browse files
Jammy2211Jammy2211
authored andcommitted
fix numpy interface to vertex areas
1 parent e1ee5d3 commit bb5bc27

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

autoarray/structures/mesh/triangulation_2d.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
2+
import scipy.spatial
33
from typing import List, Union, Tuple
44

55
from autoconf import cached_property
@@ -12,8 +12,7 @@
1212
from autoarray.structures.grids import grid_2d_util
1313

1414

15-
import numpy as np
16-
import scipy.spatial
15+
1716

1817

1918
def scipy_delaunay_padded(points_np, max_simplices):
@@ -146,7 +145,10 @@ def vertex_areas_from_delaunay(points, simplices, xp=np):
146145
vertex_area = xp.zeros(n_pts)
147146

148147
# Scatter-add: NumPy and JAX both support this API!
149-
vertex_area = vertex_area.at[scatter_idx].add(scatter_vals)
148+
if xp.__name__.startswith("jax"):
149+
vertex_area = vertex_area.at[scatter_idx].add(scatter_vals)
150+
else:
151+
np.add.at(vertex_area, scatter_idx, scatter_vals)
150152

151153
return vertex_area
152154

@@ -289,6 +291,7 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
289291
to compute the Voronoi mesh are ill posed. These exceptions are caught and combined into a single
290292
`MeshException`, which helps exception handling in the `inversion` package.
291293
"""
294+
292295
mesh_grid = self._xp.stack([self.array[:, 0], self.array[:, 1]]).T
293296

294297
if self._xp.__name__.startswith("jax"):

test_autoarray/structures/mesh/test_delaunay.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import autoarray as aa
55

6+
from autoarray.structures.mesh.triangulation_2d import vertex_areas_from_delaunay
67

78
def test__edge_pixel_list():
89
grid = np.array(
@@ -58,3 +59,4 @@ def test__interpolated_array_from():
5859
assert interpolated_array.native == pytest.approx(
5960
np.array([[1.0, 1.907216], [1.0, 1.0], [1.0, 1.0]]), 1.0e-4
6061
)
62+

0 commit comments

Comments
 (0)