Skip to content

Commit 850695c

Browse files
Jammy2211Jammy2211
authored andcommitted
added weight floor to mesh adapt
1 parent a5e3e78 commit 850695c

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

autoarray/inversion/pixelization/mesh/rectangular.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def requires_image_mesh(self):
158158

159159
class RectangularSource(RectangularMagnification):
160160

161-
def __init__(self, shape: Tuple[int, int] = (3, 3), weight_power: float = 1.0):
161+
def __init__(self, shape: Tuple[int, int] = (3, 3), weight_power: float = 1.0, weight_floor : float = 0.0):
162162
"""
163163
A uniform mesh of rectangular pixels, which without interpolation are paired with a 2D grid of (y,x)
164164
coordinates.
@@ -190,6 +190,7 @@ def __init__(self, shape: Tuple[int, int] = (3, 3), weight_power: float = 1.0):
190190
super().__init__(shape=shape)
191191

192192
self.weight_power = weight_power
193+
self.weight_floor = weight_floor
193194

194195
def mesh_weight_map_from(self, adapt_data, xp=np) -> np.ndarray:
195196
"""
@@ -205,5 +206,6 @@ def mesh_weight_map_from(self, adapt_data, xp=np) -> np.ndarray:
205206
mesh_weight_map = xp.asarray(adapt_data.array)
206207
mesh_weight_map = xp.clip(mesh_weight_map, 1e-12, None)
207208
mesh_weight_map = mesh_weight_map**self.weight_power
209+
mesh_weight_map[mesh_weight_map < self.weight_floor] = self.weight_floor
208210
mesh_weight_map /= xp.sum(mesh_weight_map)
209211
return mesh_weight_map

autoarray/operators/transformer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -440,11 +440,14 @@ def transform_mapping_matrix(self, mapping_matrix: np.ndarray, xp=np) -> np.ndar
440440
- Each column of the input mapping matrix is reshaped into the native 2D image grid before transformation.
441441
- This method repeatedly calls `visibilities_from` for each column, which may be computationally intensive.
442442
"""
443-
transformed_mapping_matrix = 0 + 0j * np.zeros(
443+
transformed_mapping_matrix = 0 + 0j * xp.zeros(
444444
(self.uv_wavelengths.shape[0], mapping_matrix.shape[1])
445445
)
446446

447447
for source_pixel_1d_index in range(mapping_matrix.shape[1]):
448+
449+
print("hi")
450+
448451
image_2d = array_2d_util.array_2d_native_from(
449452
array_2d_slim=mapping_matrix[:, source_pixel_1d_index],
450453
mask_2d=self.grid.mask,

0 commit comments

Comments
 (0)