Skip to content

Commit 9c52d10

Browse files
Jammy2211Jammy2211
authored andcommitted
all of inversion tests pass
1 parent 2ecd33b commit 9c52d10

File tree

18 files changed

+236
-135
lines changed

18 files changed

+236
-135
lines changed

autoarray/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from .inversion.pixelization.mappers.rectangular import MapperRectangular
4444
from .inversion.pixelization.mappers.delaunay import MapperDelaunay
4545
from .inversion.pixelization.mappers.voronoi import MapperVoronoi
46+
from .inversion.pixelization.mappers.rectangular_uniform import MapperRectangularUniform
4647
from .inversion.pixelization.image_mesh.abstract import AbstractImageMesh
4748
from .inversion.pixelization.mesh.abstract import AbstractMesh
4849
from .inversion.inversion.imaging.mapping import InversionImagingMapping
@@ -75,6 +76,7 @@
7576
from .operators.over_sampling.over_sampler import OverSampler
7677
from .structures.grids.irregular_2d import Grid2DIrregular
7778
from .structures.mesh.rectangular_2d import Mesh2DRectangular
79+
from .structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform
7880
from .structures.mesh.voronoi_2d import Mesh2DVoronoi
7981
from .structures.mesh.delaunay_2d import Mesh2DDelaunay
8082
from .structures.arrays.kernel_2d import Kernel2D

autoarray/fixtures.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,7 +421,7 @@ def make_rectangular_mapper_7x7_3x3():
421421
adapt_data=aa.Array2D.ones(shape_native=(3, 3), pixel_scales=0.1),
422422
)
423423

424-
return aa.MapperRectangular(
424+
return aa.MapperRectangularUniform(
425425
mapper_grids=mapper_grids,
426426
border_relocator=make_border_relocator_2d_7x7(),
427427
regularization=make_regularization_constant(),

autoarray/inversion/pixelization/mappers/factory.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from autoarray.inversion.pixelization.border_relocator import BorderRelocator
55
from autoarray.inversion.regularization.abstract import AbstractRegularization
66
from autoarray.structures.mesh.rectangular_2d import Mesh2DRectangular
7+
from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform
78
from autoarray.structures.mesh.delaunay_2d import Mesh2DDelaunay
89
from autoarray.structures.mesh.voronoi_2d import Mesh2DVoronoi
910

@@ -39,10 +40,19 @@ def mapper_from(
3940
from autoarray.inversion.pixelization.mappers.rectangular import (
4041
MapperRectangular,
4142
)
43+
from autoarray.inversion.pixelization.mappers.rectangular_uniform import (
44+
MapperRectangularUniform,
45+
)
4246
from autoarray.inversion.pixelization.mappers.delaunay import MapperDelaunay
4347
from autoarray.inversion.pixelization.mappers.voronoi import MapperVoronoi
4448

45-
if isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular):
49+
if isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangularUniform):
50+
return MapperRectangularUniform(
51+
mapper_grids=mapper_grids,
52+
border_relocator=border_relocator,
53+
regularization=regularization,
54+
)
55+
elif isinstance(mapper_grids.source_plane_mesh_grid, Mesh2DRectangular):
4656
return MapperRectangular(
4757
mapper_grids=mapper_grids,
4858
border_relocator=border_relocator,

autoarray/inversion/pixelization/mappers/mapper_util.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,6 @@ def inv_full(U):
161161
return inv_transform(U) * scale + mu
162162

163163
pixel_edges = inv_full(jnp.stack([pixel_edges_1d, pixel_edges_1d]).T)
164-
165-
# lengths along each axis
166164
pixel_lengths = jnp.diff(pixel_edges, axis=0).squeeze() # shape (N_source, 2)
167165

168166
dy = pixel_lengths[:, 0]
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import jax.numpy as jnp
2+
3+
from autoconf import cached_property
4+
5+
from autoarray.inversion.pixelization.mappers.rectangular import MapperRectangular
6+
from autoarray.inversion.pixelization.mappers.abstract import PixSubWeights
7+
8+
from autoarray.inversion.pixelization.mappers import mapper_util
9+
10+
11+
class MapperRectangularUniform(MapperRectangular):
12+
"""
13+
To understand a `Mapper` one must be familiar `Mesh` objects and the `mesh` and `pixelization` packages, where
14+
the four grids grouped in a `MapperGrids` object are explained (`image_plane_data_grid`, `source_plane_data_grid`,
15+
`image_plane_mesh_grid`,`source_plane_mesh_grid`)
16+
17+
If you are unfamliar withe above objects, read through the docstrings of the `pixelization`, `mesh` and
18+
`mapper_grids` packages.
19+
20+
A `Mapper` determines the mappings between the masked data grid's pixels (`image_plane_data_grid` and
21+
`source_plane_data_grid`) and the mesh's pixels (`image_plane_mesh_grid` and `source_plane_mesh_grid`).
22+
23+
The 1D Indexing of each grid is identical in the `data` and `source` frames (e.g. the transformation does not
24+
change the indexing, such that `source_plane_data_grid[0]` corresponds to the transformed value
25+
of `image_plane_data_grid[0]` and so on).
26+
27+
A mapper therefore only needs to determine the index mappings between the `grid_slim` and `mesh_grid`,
28+
noting that associations are made by pairing `source_plane_mesh_grid` with `source_plane_data_grid`.
29+
30+
Mappings are represented in the 2D ndarray `pix_indexes_for_sub_slim_index`, whereby the index of
31+
a pixel on the `mesh_grid` maps to the index of a pixel on the `grid_slim` as follows:
32+
33+
- pix_indexes_for_sub_slim_index[0, 0] = 0: the data's 1st sub-pixel maps to the mesh's 1st pixel.
34+
- pix_indexes_for_sub_slim_index[1, 0] = 3: the data's 2nd sub-pixel maps to the mesh's 4th pixel.
35+
- pix_indexes_for_sub_slim_index[2, 0] = 1: the data's 3rd sub-pixel maps to the mesh's 2nd pixel.
36+
37+
The second dimension of this array (where all three examples above are 0) is used for cases where a
38+
single pixel on the `grid_slim` maps to multiple pixels on the `mesh_grid`. For example, a
39+
`Delaunay` triangulation, where every `grid_slim` pixel maps to three Delaunay pixels (the corners of the
40+
triangles) with varying interpolation weights .
41+
42+
For a `Rectangular` mesh every pixel in the masked data maps to only one pixel, thus the second
43+
dimension of `pix_indexes_for_sub_slim_index` is always of size 1.
44+
45+
The mapper allows us to create a mapping matrix, which is a matrix representing the mapping between every
46+
unmasked data pixel annd the pixels of a mesh. This matrix is the basis of performing an `Inversion`,
47+
which reconstructs the data using the `source_plane_mesh_grid`.
48+
49+
Parameters
50+
----------
51+
mapper_grids
52+
An object containing the data grid and mesh grid in both the data-frame and source-frame used by the
53+
mapper to map data-points to linear object parameters.
54+
regularization
55+
The regularization scheme which may be applied to this linear object in order to smooth its solution,
56+
which for a mapper smooths neighboring pixels on the mesh.
57+
"""
58+
59+
@cached_property
60+
def pix_sub_weights(self) -> PixSubWeights:
61+
"""
62+
Computes the following three quantities describing the mappings between of every sub-pixel in the masked data
63+
and pixel in the `Rectangular` mesh.
64+
65+
- `pix_indexes_for_sub_slim_index`: the mapping of every data pixel (given its `sub_slim_index`)
66+
to mesh pixels (given their `pix_indexes`).
67+
68+
- `pix_sizes_for_sub_slim_index`: the number of mappings of every data pixel to mesh pixels.
69+
70+
- `pix_weights_for_sub_slim_index`: the interpolation weights of every data pixel's mesh
71+
pixel mapping
72+
73+
These are packaged into the class `PixSubWeights` with attributes `mappings`, `sizes` and `weights`.
74+
75+
The `sub_slim_index` refers to the masked data sub-pixels and `pix_indexes` the mesh pixel indexes,
76+
for example:
77+
78+
- `pix_indexes_for_sub_slim_index[0, 0] = 2`: The data's first (index 0) sub-pixel maps to the Rectangular
79+
mesh's third (index 2) pixel.
80+
81+
- `pix_indexes_for_sub_slim_index[2, 0] = 4`: The data's third (index 2) sub-pixel maps to the Rectangular
82+
mesh's fifth (index 4) pixel.
83+
84+
The second dimension of the array `pix_indexes_for_sub_slim_index`, which is 0 in both examples above, is used
85+
for cases where a data pixel maps to more than one mesh pixel (for example a `Delaunay` triangulation
86+
where each data pixel maps to 3 Delaunay triangles with interpolation weights). The weights of multiple mappings
87+
are stored in the array `pix_weights_for_sub_slim_index`.
88+
89+
For a Rectangular pixelization each data sub-pixel maps to a single mesh pixel, thus the second
90+
dimension of the array `pix_indexes_for_sub_slim_index` 1 and all entries in `pix_weights_for_sub_slim_index`
91+
are equal to 1.0.
92+
"""
93+
94+
mappings, weights = (
95+
mapper_util.rectangular_mappings_weights_via_interpolation_from(
96+
shape_native=self.shape_native,
97+
source_plane_mesh_grid=self.source_plane_mesh_grid.array,
98+
source_plane_data_grid=self.source_plane_data_grid.over_sampled,
99+
)
100+
)
101+
102+
return PixSubWeights(
103+
mappings=mappings,
104+
sizes=4 * jnp.ones(len(mappings), dtype="int"),
105+
weights=weights,
106+
)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .abstract import AbstractMesh as Mesh
22
from .rectangular import Rectangular
3+
from .rectangular_uniform import RectangularUniform
34
from .voronoi import Voronoi
45
from .delaunay import Delaunay

autoarray/inversion/pixelization/mesh/rectangular.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import numpy as np
2-
from typing import Dict, Optional, Tuple
2+
from typing import Optional, Tuple
33

44

55
from autoarray.structures.grids.uniform_2d import Grid2D
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from autoarray.inversion.pixelization.mesh.rectangular import Rectangular
2+
3+
from typing import Optional
4+
5+
6+
from autoarray.structures.grids.uniform_2d import Grid2D
7+
from autoarray.structures.mesh.rectangular_2d_uniform import Mesh2DRectangularUniform
8+
9+
class RectangularUniform(Rectangular):
10+
11+
def mesh_grid_from(
12+
self,
13+
source_plane_data_grid: Optional[Grid2D] = None,
14+
source_plane_mesh_grid: Optional[Grid2D] = None,
15+
) -> Mesh2DRectangularUniform:
16+
"""
17+
Return the rectangular `source_plane_mesh_grid` as a `Mesh2DRectangular` object, which provides additional
18+
functionality for perform operatons that exploit the geometry of a rectangular pixelization.
19+
20+
Parameters
21+
----------
22+
source_plane_data_grid
23+
The (y,x) grid of coordinates over which the rectangular pixelization is overlaid, where this grid may have
24+
had exterior pixels relocated to its edge via the border.
25+
source_plane_mesh_grid
26+
Not used for a rectangular pixelization, because the pixelization grid in the `source` frame is computed
27+
by overlaying the `source_plane_data_grid` with the rectangular pixelization.
28+
"""
29+
return Mesh2DRectangularUniform.overlay_grid(
30+
shape_native=self.shape, grid=source_plane_data_grid.over_sampled
31+
)

autoarray/plot/mat_plot/two_d.py

Lines changed: 53 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -568,64 +568,71 @@ def _plot_rectangular_mapper(
568568

569569
if pixel_values is not None:
570570

571-
# self.plot_array(
572-
# array=pixel_values,
573-
# visuals_2d=visuals_2d,
574-
# auto_labels=auto_labels,
575-
# bypass=True,
576-
# )
577-
578-
norm = self.cmap.norm_from(
579-
array=pixel_values.array, use_log10=self.use_log10
580-
)
571+
from autoarray.inversion.pixelization.mappers.rectangular_uniform import MapperRectangularUniform
572+
from autoarray.inversion.pixelization.mappers.rectangular import MapperRectangular
581573

582-
edges_transformed = mapper.edges_transformed
574+
if isinstance(mapper, MapperRectangularUniform):
583575

584-
edges_transformed_dense = np.moveaxis(
585-
np.stack(np.meshgrid(*edges_transformed.T)), 0, 2
586-
)
576+
self.plot_array(
577+
array=pixel_values,
578+
visuals_2d=visuals_2d,
579+
auto_labels=auto_labels,
580+
bypass=True,
581+
)
587582

588-
plt.pcolormesh(
589-
edges_transformed_dense[..., 0],
590-
edges_transformed_dense[..., 1],
591-
pixel_values.array.reshape(shape_native),
592-
shading="flat",
593-
norm=norm,
594-
cmap=self.cmap.cmap,
595-
)
583+
else:
596584

597-
if self.colorbar is not False:
585+
norm = self.cmap.norm_from(
586+
array=pixel_values.array, use_log10=self.use_log10
587+
)
598588

599-
cb = self.colorbar.set(
600-
units=self.units,
601-
ax=ax,
589+
edges_transformed = mapper.edges_transformed
590+
591+
edges_transformed_dense = np.moveaxis(
592+
np.stack(np.meshgrid(*edges_transformed.T)), 0, 2
593+
)
594+
595+
plt.pcolormesh(
596+
edges_transformed_dense[..., 0],
597+
edges_transformed_dense[..., 1],
598+
pixel_values.array.reshape(shape_native),
599+
shading="flat",
602600
norm=norm,
603-
cb_unit=auto_labels.cb_unit,
604-
use_log10=self.use_log10,
601+
cmap=self.cmap.cmap,
605602
)
606-
self.colorbar_tickparams.set(cb=cb)
607603

608-
extent_axis = self.axis.config_dict.get("extent")
604+
if self.colorbar is not False:
609605

610-
if extent_axis is None:
611-
extent_axis = extent
606+
cb = self.colorbar.set(
607+
units=self.units,
608+
ax=ax,
609+
norm=norm,
610+
cb_unit=auto_labels.cb_unit,
611+
use_log10=self.use_log10,
612+
)
613+
self.colorbar_tickparams.set(cb=cb)
612614

613-
self.axis.set(extent=extent_axis)
615+
extent_axis = self.axis.config_dict.get("extent")
614616

615-
self.tickparams.set()
616-
self.yticks.set(
617-
min_value=extent_axis[2],
618-
max_value=extent_axis[3],
619-
units=self.units,
620-
pixels=shape_native[0],
621-
)
617+
if extent_axis is None:
618+
extent_axis = extent
622619

623-
self.xticks.set(
624-
min_value=extent_axis[0],
625-
max_value=extent_axis[1],
626-
units=self.units,
627-
pixels=shape_native[1],
628-
)
620+
self.axis.set(extent=extent_axis)
621+
622+
self.tickparams.set()
623+
self.yticks.set(
624+
min_value=extent_axis[2],
625+
max_value=extent_axis[3],
626+
units=self.units,
627+
pixels=shape_native[0],
628+
)
629+
630+
self.xticks.set(
631+
min_value=extent_axis[0],
632+
max_value=extent_axis[1],
633+
units=self.units,
634+
pixels=shape_native[1],
635+
)
629636

630637
if not isinstance(self.text, list):
631638
self.text.set()

autoarray/structures/mesh/rectangular_2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def overlay_grid(
109109
origin=origin,
110110
)
111111

112-
return Mesh2DRectangular(
112+
return cls(
113113
values=grid_slim,
114114
shape_native=shape_native,
115115
pixel_scales=pixel_scales,

0 commit comments

Comments
 (0)