Skip to content

Commit 4547f14

Browse files
Jammy2211Jammy2211
authored andcommitted
edge area diff code
1 parent 1bb2ffd commit 4547f14

File tree

6 files changed

+26
-9
lines changed

6 files changed

+26
-9
lines changed

autoarray/inversion/pixelization/image_mesh/overlay.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,9 +201,6 @@ def image_plane_mesh_grid_from(
201201
adapt_data
202202
Not used by this image mesh.
203203
"""
204-
205-
print(mask.pixels_in_mask)
206-
207204
pixel_scales = mask.pixel_scales
208205

209206
grid = mask.derive_grid.unmasked

autoarray/inversion/pixelization/mappers/abstract.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
mapper_grids: MapperGrids,
2626
regularization: Optional[AbstractRegularization],
2727
border_relocator: BorderRelocator,
28+
preloads=None,
2829
xp=np,
2930
):
3031
"""
@@ -88,6 +89,7 @@ def __init__(
8889

8990
self.border_relocator = border_relocator
9091
self.mapper_grids = mapper_grids
92+
self.preloads = preloads
9193

9294
@property
9395
def params(self) -> int:

autoarray/inversion/pixelization/mappers/factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ def mapper_from(
1313
mapper_grids: MapperGrids,
1414
regularization: Optional[AbstractRegularization],
1515
border_relocator: Optional[BorderRelocator] = None,
16+
preloads=None,
1617
xp=np,
1718
):
1819
"""
@@ -66,5 +67,6 @@ def mapper_from(
6667
mapper_grids=mapper_grids,
6768
border_relocator=border_relocator,
6869
regularization=regularization,
70+
preloads=preloads,
6971
xp=xp,
7072
)

autoarray/inversion/pixelization/mesh/delaunay.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def mesh_grid_from(
4040
self,
4141
source_plane_data_grid=None,
4242
source_plane_mesh_grid=None,
43+
preloads=None,
4344
xp=np,
4445
):
4546
"""
@@ -62,6 +63,7 @@ def mesh_grid_from(
6263
return Mesh2DDelaunay(
6364
values=source_plane_mesh_grid,
6465
source_plane_data_grid_over_sampled=source_plane_data_grid,
66+
preloads=preloads,
6567
_xp=xp,
6668
)
6769

@@ -73,6 +75,7 @@ def mapper_grids_from(
7375
source_plane_mesh_grid: Optional[Grid2DIrregular] = None,
7476
image_plane_mesh_grid: Optional[Grid2DIrregular] = None,
7577
adapt_data: np.ndarray = None,
78+
preloads=None,
7679
xp=np,
7780
) -> MapperGrids:
7881
"""
@@ -132,6 +135,7 @@ def mapper_grids_from(
132135
source_plane_mesh_grid = self.mesh_grid_from(
133136
source_plane_data_grid=relocated_grid.over_sampled,
134137
source_plane_mesh_grid=relocated_mesh_grid,
138+
preloads=preloads,
135139
xp=xp,
136140
)
137141
except ValueError as e:

autoarray/inversion/pixelization/mesh/rectangular.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ def mapper_grids_from(
7373
source_plane_mesh_grid: Grid2D = None,
7474
image_plane_mesh_grid: Grid2D = None,
7575
adapt_data: np.ndarray = None,
76+
preloads=None,
7677
xp=np,
7778
) -> MapperGrids:
7879
"""

autoarray/structures/mesh/delaunay_2d.py

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from autoarray.inversion.pixelization.mesh import mesh_numba_util
1515

1616

17-
def scipy_delaunay(points_np, query_points_np, max_simplices):
17+
def scipy_delaunay(points_np, query_points_np, source_pixel_zeroed_indices, max_simplices=None):
1818
"""Compute Delaunay simplices (simplices_padded) and Voronoi areas in one call."""
1919

2020
# --- Delaunay mesh using source plane data grid ---
@@ -44,6 +44,12 @@ def scipy_delaunay(points_np, query_points_np, max_simplices):
4444
xp=np,
4545
)
4646

47+
# max_area = np.percentile(barycentric_dual_areas, 90.0)
48+
# barycentric_dual_areas[source_pixel_zeroed_indices] = max_area
49+
50+
max_area = 100.0 * np.max(barycentric_dual_areas)
51+
barycentric_dual_areas[source_pixel_zeroed_indices] = max_area
52+
4753
# ---------- Areas used to weight split points ----------
4854
split_point_areas = 0.5 * np.sqrt(barycentric_dual_areas)
4955

@@ -66,7 +72,7 @@ def scipy_delaunay(points_np, query_points_np, max_simplices):
6672
return points, simplices_padded, mappings, split_points, splitted_mappings
6773

6874

69-
def jax_delaunay(points, query_points):
75+
def jax_delaunay(points, query_points, source_pixel_zeroed_indices):
7076
import jax
7177
import jax.numpy as jnp
7278

@@ -81,8 +87,8 @@ def jax_delaunay(points, query_points):
8187
splitted_mappings_shape = jax.ShapeDtypeStruct((N * 4, 3), jnp.int32)
8288

8389
return jax.pure_callback(
84-
lambda points, qpts: scipy_delaunay(
85-
np.asarray(points), np.asarray(qpts), max_simplices
90+
lambda points, qpts, spzi: scipy_delaunay(
91+
np.asarray(points), np.asarray(qpts), np.asarray(spzi), max_simplices
8692
),
8793
(
8894
points_shape,
@@ -93,6 +99,7 @@ def jax_delaunay(points, query_points):
9399
),
94100
points,
95101
query_points,
102+
source_pixel_zeroed_indices
96103
)
97104

98105

@@ -272,6 +279,7 @@ def __init__(
272279
self,
273280
values: Union[np.ndarray, List],
274281
source_plane_data_grid_over_sampled=None,
282+
preloads=None,
275283
_xp=np,
276284
):
277285
"""
@@ -307,6 +315,7 @@ def __init__(
307315
super().__init__(values, xp=_xp)
308316

309317
self._source_plane_data_grid_over_sampled = source_plane_data_grid_over_sampled
318+
self.preloads = preloads
310319

311320
@property
312321
def geometry(self):
@@ -361,14 +370,16 @@ def delaunay(self) -> "scipy.spatial.Delaunay":
361370

362371
points, simplices, mappings, split_points, splitted_mappings = jax_delaunay(
363372
points=self.mesh_grid_xy,
364-
query_points=self._source_plane_data_grid_over_sampled
373+
query_points=self._source_plane_data_grid_over_sampled,
374+
source_pixel_zeroed_indices=self.preloads.source_pixel_zeroed_indices
365375
)
366376

367377
else:
368378

369379
points, simplices, mappings, split_points, splitted_mappings = scipy_delaunay(
370380
points=self.mesh_grid_xy,
371-
query_points_np=self.source_plane_data_grid_over_sampled
381+
query_points_np=self.source_plane_data_grid_over_sampled,
382+
source_pixel_zeroed_indices=self.preloads.source_pixel_zeroed_indices
372383
)
373384

374385
return DelaunayInterface(

0 commit comments

Comments
 (0)