Skip to content

Commit 09d81cf

Browse files
Jammy2211Jammy2211
authored andcommitted
all inverison unitt ests pass meaning JAx conversion works
1 parent 9e7b1de commit 09d81cf

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

autoarray/dataset/imaging/dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -475,10 +475,10 @@ def apply_over_sampling(
475475
return dataset
476476

477477
def apply_w_tilde(
478-
self,
479-
batch_size: int = 128,
480-
disable_fft_pad: bool = False,
481-
use_jax: bool = False,
478+
self,
479+
batch_size: int = 128,
480+
disable_fft_pad: bool = False,
481+
use_jax: bool = False,
482482
):
483483
"""
484484
The w_tilde formalism of the linear algebra equations precomputes the convolution of every pair of masked

autoarray/inversion/inversion/imaging/w_tilde.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def _data_vector_mapper(self) -> np.ndarray:
7878
if not self.has(cls=AbstractMapper):
7979
return None
8080

81-
data_vector = np.zeros(self.total_params)
81+
data_vector = self._xp.zeros(self.total_params)
8282

8383
mapper_list = self.cls_list_from(cls=AbstractMapper)
8484
mapper_param_range = self.param_range_list_from(cls=AbstractMapper)
@@ -98,7 +98,13 @@ def _data_vector_mapper(self) -> np.ndarray:
9898
)
9999
param_range = mapper_param_range[mapper_index]
100100

101-
data_vector[param_range[0] : param_range[1],] = data_vector_mapper
101+
start = param_range[0]
102+
end = param_range[1]
103+
104+
if self._xp is np:
105+
data_vector[start:end] = data_vector_mapper
106+
else:
107+
data_vector = data_vector.at[start:end].set(data_vector_mapper)
102108

103109
return data_vector
104110

@@ -186,7 +192,7 @@ def _data_vector_func_list_and_mapper(self) -> np.ndarray:
186192
separation of functions enables the `data_vector` to be preloaded in certain circumstances.
187193
"""
188194

189-
data_vector = np.array(self._data_vector_mapper)
195+
data_vector = self._xp.array(self._data_vector_mapper)
190196

191197
linear_func_param_range = self.param_range_list_from(
192198
cls=AbstractLinearObjFuncList

0 commit comments

Comments
 (0)