Skip to content

Commit 67a5830

Browse files
Jammy2211Jammy2211
authored andcommitted
black
1 parent 12b2d72 commit 67a5830

File tree

11 files changed

+66
-55
lines changed

11 files changed

+66
-55
lines changed

autoarray/inversion/inversion/abstract.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,7 @@ def param_range_list_from(self, cls: Type) -> List[List[int]]:
150150
A list of the index range of the parameters of each linear object in the inversion of the input cls type.
151151
"""
152152
return inversion_util.param_range_list_from(
153-
cls=cls,
154-
linear_obj_list=self.linear_obj_list
153+
cls=cls, linear_obj_list=self.linear_obj_list
155154
)
156155

157156
def cls_list_from(self, cls: Type, cls_filtered: Optional[Type] = None) -> List:
@@ -429,20 +428,26 @@ def reconstruction(self) -> np.ndarray:
429428

430429
# Use advanced indexing to select rows/columns
431430
data_vector = self.data_vector[ids_to_keep]
432-
curvature_reg_matrix = self.curvature_reg_matrix[ids_to_keep][:, ids_to_keep]
431+
curvature_reg_matrix = self.curvature_reg_matrix[ids_to_keep][
432+
:, ids_to_keep
433+
]
433434

434435
# Perform reconstruction via fnnls
435-
reconstruction_partial = inversion_util.reconstruction_positive_only_from(
436-
data_vector=data_vector,
437-
curvature_reg_matrix=curvature_reg_matrix,
438-
settings=self.settings,
436+
reconstruction_partial = (
437+
inversion_util.reconstruction_positive_only_from(
438+
data_vector=data_vector,
439+
curvature_reg_matrix=curvature_reg_matrix,
440+
settings=self.settings,
441+
)
439442
)
440443

441444
# Allocate full solution array
442445
reconstruction = jnp.zeros(self.data_vector.shape[0])
443446

444447
# Scatter the partial solution back to the full shape
445-
reconstruction = reconstruction.at[ids_to_keep].set(reconstruction_partial)
448+
reconstruction = reconstruction.at[ids_to_keep].set(
449+
reconstruction_partial
450+
)
446451

447452
return reconstruction
448453

@@ -638,7 +643,9 @@ def log_det_curvature_reg_matrix_term(self) -> float:
638643

639644
try:
640645
return 2.0 * np.sum(
641-
jnp.log(jnp.diag(jnp.linalg.cholesky(self.curvature_reg_matrix_reduced)))
646+
jnp.log(
647+
jnp.diag(jnp.linalg.cholesky(self.curvature_reg_matrix_reduced))
648+
)
642649
)
643650
except np.linalg.LinAlgError as e:
644651
raise exc.InversionException() from e

autoarray/inversion/inversion/factory.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def inversion_from(
2222
dataset: Union[Imaging, Interferometer, DatasetInterface],
2323
linear_obj_list: List[LinearObj],
2424
settings: SettingsInversion = SettingsInversion(),
25-
preloads :Preloads = None,
25+
preloads: Preloads = None,
2626
):
2727
"""
2828
Factory which given an input dataset and list of linear objects, creates an `Inversion`.
@@ -71,7 +71,7 @@ def inversion_imaging_from(
7171
dataset,
7272
linear_obj_list: List[LinearObj],
7373
settings: SettingsInversion = SettingsInversion(),
74-
preloads : Preloads = None,
74+
preloads: Preloads = None,
7575
):
7676
"""
7777
Factory which given an input `Imaging` dataset and list of linear objects, creates an `InversionImaging`.

autoarray/inversion/inversion/interferometer/inversion_interferometer_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1871,4 +1871,4 @@ def sub_slim_indexes_for_pix_index(
18711871
sub_slim_indexes_for_pix_index,
18721872
sub_slim_sizes_for_pix_index,
18731873
sub_slim_weights_for_pix_index,
1874-
)
1874+
)

autoarray/inversion/inversion/interferometer/w_tilde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def curvature_matrix_diag(self) -> np.ndarray:
130130
sub_slim_indexes_for_pix_index,
131131
sub_slim_sizes_for_pix_index,
132132
sub_slim_weights_for_pix_index,
133-
) = inversion_interferometer_util.sub_slim_indexes_for_pix_index(
133+
) = inversion_interferometer_util.sub_slim_indexes_for_pix_index(
134134
pix_indexes_for_sub_slim_index=mapper.pix_indexes_for_sub_slim_index,
135135
pix_weights_for_sub_slim_index=mapper.pix_weights_for_sub_slim_index,
136136
pix_pixels=mapper.pixels,

autoarray/inversion/inversion/inversion_util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,4 +390,4 @@ def param_range_list_from(cls: Type, linear_obj_list) -> List[List[int]]:
390390

391391
pixel_count += linear_obj.params
392392

393-
return index_list
393+
return index_list

autoarray/inversion/pixelization/mappers/mapper_util.py

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -499,31 +499,33 @@ def adaptive_pixel_signals_from(
499499
M_sub, B = pix_indexes_for_sub_slim_index.shape
500500

501501
# 1) Flatten the per‐mapping tables:
502-
flat_pixidx = pix_indexes_for_sub_slim_index.reshape(-1) # (M_sub*B,)
503-
flat_weights = pixel_weights.reshape(-1) # (M_sub*B,)
502+
flat_pixidx = pix_indexes_for_sub_slim_index.reshape(-1) # (M_sub*B,)
503+
flat_weights = pixel_weights.reshape(-1) # (M_sub*B,)
504504

505505
# 2) Build a matching “parent‐slim” index for each flattened entry:
506-
I_sub = jnp.repeat(jnp.arange(M_sub), B) # (M_sub*B,)
506+
I_sub = jnp.repeat(jnp.arange(M_sub), B) # (M_sub*B,)
507507

508508
# 3) Mask out any k >= pix_size_for_sub_slim_index[i]
509-
valid = (I_sub < 0) # dummy to get shape
509+
valid = I_sub < 0 # dummy to get shape
510510
# better:
511511
valid = (jnp.arange(B)[None, :] < pix_size_for_sub_slim_index[:, None]).reshape(-1)
512512

513513
flat_weights = jnp.where(valid, flat_weights, 0.0)
514-
flat_pixidx = jnp.where(valid, flat_pixidx, pixels) # send invalid indices to an out-of-bounds slot
514+
flat_pixidx = jnp.where(
515+
valid, flat_pixidx, pixels
516+
) # send invalid indices to an out-of-bounds slot
515517

516518
# 4) Look up data & multiply by mapping weights:
517519
flat_data_vals = adapt_data[slim_index_for_sub_slim_index][I_sub] # (M_sub*B,)
518-
flat_contrib = flat_data_vals * flat_weights # (M_sub*B,)
520+
flat_contrib = flat_data_vals * flat_weights # (M_sub*B,)
519521

520522
# 5) Scatter‐add into signal sums and counts:
521-
pixel_signals = jnp.zeros((pixels+1,)).at[flat_pixidx].add(flat_contrib)
522-
pixel_counts = jnp.zeros((pixels+1,)).at[flat_pixidx].add(valid.astype(float))
523+
pixel_signals = jnp.zeros((pixels + 1,)).at[flat_pixidx].add(flat_contrib)
524+
pixel_counts = jnp.zeros((pixels + 1,)).at[flat_pixidx].add(valid.astype(float))
523525

524526
# 6) Drop the extra “out-of-bounds” slot:
525527
pixel_signals = pixel_signals[:pixels]
526-
pixel_counts = pixel_counts[:pixels]
528+
pixel_counts = pixel_counts[:pixels]
527529

528530
# 7) Normalize
529531
pixel_counts = jnp.where(pixel_counts > 0, pixel_counts, 1.0)
@@ -532,7 +534,7 @@ def adaptive_pixel_signals_from(
532534
pixel_signals = jnp.where(max_sig > 0, pixel_signals / max_sig, pixel_signals)
533535

534536
# 8) Exponentiate
535-
return pixel_signals ** signal_scale
537+
return pixel_signals**signal_scale
536538

537539

538540
def mapping_matrix_from(
@@ -652,27 +654,27 @@ def mapped_to_source_via_mapping_matrix_from(
652654
mapping_matrix: np.ndarray, array_slim: np.ndarray
653655
) -> np.ndarray:
654656
"""
655-
Map a masked 2D image (in slim form) into the source plane by summing and averaging
656-
each image-pixel's contribution to its mapped source-pixels.
657-
658-
Each row i of `mapping_matrix` describes how image-pixel i is distributed (with
659-
weights) across the source-pixels j. `array_slim[i]` is then multiplied by those
660-
weights and summed over i to give each source-pixel’s total mapped value; finally,
661-
we divide by the number of nonzero contributions to form an average.
662-
663-
Parameters
664-
----------
665-
mapping_matrix : ndarray of shape (M, N)
666-
mapping_matrix[i, j] ≥ 0 is the weight by which image-pixel i contributes to
667-
source-pixel j. Zero means “no contribution.”
668-
array_slim : ndarray of shape (M,)
669-
The slimmed image values for each image-pixel i.
670-
671-
Returns
672-
-------
673-
mapped_to_source : ndarray of shape (N,)
674-
The averaged, mapped values on each of the N source-pixels.
675-
"""
657+
Map a masked 2D image (in slim form) into the source plane by summing and averaging
658+
each image-pixel's contribution to its mapped source-pixels.
659+
660+
Each row i of `mapping_matrix` describes how image-pixel i is distributed (with
661+
weights) across the source-pixels j. `array_slim[i]` is then multiplied by those
662+
weights and summed over i to give each source-pixel’s total mapped value; finally,
663+
we divide by the number of nonzero contributions to form an average.
664+
665+
Parameters
666+
----------
667+
mapping_matrix : ndarray of shape (M, N)
668+
mapping_matrix[i, j] ≥ 0 is the weight by which image-pixel i contributes to
669+
source-pixel j. Zero means “no contribution.”
670+
array_slim : ndarray of shape (M,)
671+
The slimmed image values for each image-pixel i.
672+
673+
Returns
674+
-------
675+
mapped_to_source : ndarray of shape (N,)
676+
The averaged, mapped values on each of the N source-pixels.
677+
"""
676678
# weighted sums: sum over i of array_slim[i] * mapping_matrix[i, j]
677679
# ==> vector‐matrix multiply: (1×M) dot (M×N) → (N,)
678680
mapped_to_source = array_slim @ mapping_matrix
@@ -722,4 +724,3 @@ def data_weight_total_for_pix_from(
722724

723725
# Sum weights by pixel index
724726
return np.bincount(flat_idxs, weights=flat_weights, minlength=pixels)
725-

autoarray/inversion/regularization/regularization_util.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def weighted_regularization_matrix_from(
309309
#
310310
# return regularization_matrix
311311

312+
312313
def brightness_zeroth_regularization_matrix_from(
313314
regularization_weights: np.ndarray,
314315
) -> np.ndarray:
@@ -330,7 +331,6 @@ def brightness_zeroth_regularization_matrix_from(
330331
return np.diag(regularization_weight_squared)
331332

332333

333-
334334
def reg_split_from(
335335
splitted_mappings: np.ndarray,
336336
splitted_sizes: np.ndarray,
@@ -447,8 +447,7 @@ def pixel_splitted_regularization_matrix_from(
447447

448448
# Outer product of weights and symmetric updates
449449
outer = np.outer(weight, weight) * reg_w
450-
rows, cols = np.meshgrid(mapping, mapping, indexing='ij')
450+
rows, cols = np.meshgrid(mapping, mapping, indexing="ij")
451451
regularization_matrix[rows, cols] += outer
452452

453453
return regularization_matrix
454-

autoarray/preloads.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010

1111
class Preloads:
1212

13-
def __init__(self, mapper_indices: np.ndarray = None, source_pixel_zeroed_indices: np.ndarray = None):
13+
def __init__(
14+
self,
15+
mapper_indices: np.ndarray = None,
16+
source_pixel_zeroed_indices: np.ndarray = None,
17+
):
1418
"""
1519
Stores preloaded arrays and matrices used during pixelized linear inversions, improving both performance
1620
and compatibility with JAX.

test_autoarray/inversion/inversion/imaging/test_imaging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,5 +163,5 @@ def test__w_tilde_checks_noise_map_and_raises_exception_if_preloads_dont_match_n
163163
mapping_matrix=np.ones(matrix_shape), source_plane_data_grid=grid
164164
)
165165
],
166-
settings=aa.SettingsInversion(use_w_tilde=True)
166+
settings=aa.SettingsInversion(use_w_tilde=True),
167167
)

test_autoarray/inversion/inversion/test_factory.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,8 @@ def test__inversion_imaging__source_pixel_zeroed_indices(
198198
linear_obj_list=[rectangular_mapper_7x7_3x3],
199199
settings=aa.SettingsInversion(use_w_tilde=False, use_positive_only_solver=True),
200200
preloads=aa.Preloads(
201-
mapper_indices=range(0, 9),
202-
source_pixel_zeroed_indices=np.array([0])
203-
)
201+
mapper_indices=range(0, 9), source_pixel_zeroed_indices=np.array([0])
202+
),
204203
)
205204

206205
assert inversion.reconstruction.shape[0] == 9
@@ -576,7 +575,7 @@ def test__inversion_matrices__x2_mappers(
576575

577576
assert inversion.reconstruction_dict[rectangular_mapper_7x7_3x3][
578577
4
579-
] == pytest.approx( 0.5000029374603968, 1.0e-4)
578+
] == pytest.approx(0.5000029374603968, 1.0e-4)
580579
assert inversion.reconstruction_dict[delaunay_mapper_9_3x3][4] == pytest.approx(
581580
0.4999970390886761, 1.0e-4
582581
)

0 commit comments

Comments
 (0)