Skip to content

Commit de9a08a

Browse files
Jammy2211Jammy2211
authored andcommitted
xp module not class attribute
1 parent bdeabaf commit de9a08a

File tree

8 files changed

+78
-8
lines changed

8 files changed

+78
-8
lines changed

autoarray/abstract_ndarray.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,16 @@ def __init__(self, array, xp=np):
7373
while isinstance(array, AbstractNDArray):
7474
array = array.array
7575
self._array = array
76-
self._xp = xp
76+
77+
self.use_jax = xp is not np
78+
79+
@property
80+
def _xp(self):
81+
if self.use_jax:
82+
import jax.numpy as jnp
83+
84+
return jnp
85+
return np
7786

7887
def invert(self):
7988
new = self.copy()

autoarray/fit/fit_dataset.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,15 @@ def __init__(
158158
self.dataset.grids.blurring
159159
self.dataset.grids.border_relocator
160160

161-
self._xp = xp
161+
self.use_jax = xp is not np
162+
163+
@property
164+
def _xp(self):
165+
if self.use_jax:
166+
import jax.numpy as jnp
167+
168+
return jnp
169+
return np
162170

163171
@property
164172
def mask(self) -> Mask2D:

autoarray/inversion/inversion/abstract.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,15 @@ def __init__(
7575

7676
self.preloads = preloads or Preloads()
7777

78-
self._xp = xp
78+
self.use_jax = xp is not np
79+
80+
@property
81+
def _xp(self):
82+
if self.use_jax:
83+
import jax.numpy as jnp
84+
85+
return jnp
86+
return np
7987

8088
@property
8189
def data(self):

autoarray/inversion/linear_obj/linear_obj.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,16 @@ def __init__(self, regularization: Optional[AbstractRegularization], xp=np):
2727
The regularization scheme which may be applied to this linear object in order to smooth its solution.
2828
"""
2929
self.regularization = regularization
30-
self._xp = xp
30+
31+
self.use_jax = xp is not np
32+
33+
@property
34+
def _xp(self):
35+
if self.use_jax:
36+
import jax.numpy as jnp
37+
38+
return jnp
39+
return np
3140

3241
@property
3342
def params(self) -> int:

autoarray/mask/abstract_mask.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,16 @@ def __init__(
5757

5858
self.pixel_scales = pixel_scales
5959
self.origin = origin
60-
self._xp = xp
60+
61+
self.use_jax = xp is not np
62+
63+
@property
64+
def _xp(self):
65+
if self.use_jax:
66+
import jax.numpy as jnp
67+
68+
return jnp
69+
return np
6170

6271
@property
6372
def mask(self):

autoarray/mask/derive/grid_2d.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,16 @@ def __init__(self, mask: Mask2D, xp=np):
6161
print(derive_grid_2d.border)
6262
"""
6363
self.mask = mask
64-
self._xp = xp
64+
65+
self.use_jax = xp is not np
66+
67+
@property
68+
def _xp(self):
69+
if self.use_jax:
70+
import jax.numpy as jnp
71+
72+
return jnp
73+
return np
6574

6675
def tree_flatten(self):
6776
return (self.mask,), ()

autoarray/mask/derive/indexes_2d.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,16 @@ def __init__(self, mask: Mask2D, xp=np):
6262
print(derive_indexes_2d.edge_native)
6363
"""
6464
self.mask = mask
65-
self._xp = xp
65+
66+
self.use_jax = xp is not np
67+
68+
@property
69+
def _xp(self):
70+
if self.use_jax:
71+
import jax.numpy as jnp
72+
73+
return jnp
74+
return np
6675

6776
def tree_flatten(self):
6877
return (self.mask,), ()

autoarray/structures/decorators/abstract.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,19 @@ def __init__(self, func, obj, grid, xp=np, *args, **kwargs):
5252
self.func = func
5353
self.obj = obj
5454
self.grid = grid
55-
self._xp = xp
5655
self.args = args
5756
self.kwargs = kwargs
5857

58+
self.use_jax = xp is not np
59+
60+
@property
61+
def _xp(self):
62+
if self.use_jax:
63+
import jax.numpy as jnp
64+
65+
return jnp
66+
return np
67+
5968
@property
6069
def mask(self) -> Union[Mask1D, Mask2D]:
6170
return self.grid.mask

0 commit comments

Comments
 (0)