Skip to content

Commit 6224b92

Browse files
Jammy2211Jammy2211
authored andcommitted
jax imports deferred
1 parent 9fe34bc commit 6224b92

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

autoarray/abstract_ndarray.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from abc import ABC
66
from abc import abstractmethod
7-
import jax.numpy as jnp
7+
88
from jax._src.tree_util import register_pytree_node
99

1010
import numpy as np
@@ -88,7 +88,7 @@ def __init__(self, array, xp=np):
8888

8989
def invert(self):
9090
new = self.copy()
91-
new._array = jnp.invert(new._array)
91+
new._array = self._xp.invert(new._array)
9292
return new
9393

9494
@classmethod
@@ -117,7 +117,7 @@ def instance_unflatten(cls, aux_data, children):
117117
setattr(instance, key, value)
118118
return instance
119119

120-
def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
120+
def with_new_array(self, array: np.ndarray) -> "AbstractNDArray":
121121
"""
122122
Copy this object but give it a new array.
123123
@@ -137,10 +137,9 @@ def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
137137
new_array._array = array
138138
return new_array
139139

140-
@staticmethod
141-
def flip_hdu_for_ds9(values):
140+
def flip_hdu_for_ds9(self, values):
142141
if conf.instance["general"]["fits"]["flip_for_ds9"]:
143-
return jnp.flipud(values)
142+
return self._xp.flipud(values)
144143
return values
145144

146145
def copy(self):
@@ -170,7 +169,7 @@ def __iter__(self):
170169

171170
@to_new_array
172171
def sqrt(self):
173-
return jnp.sqrt(self._array)
172+
return self._xp.sqrt(self._array)
174173

175174
@property
176175
def array(self):
@@ -333,7 +332,10 @@ def __getattr__(self, item):
333332
)
334333

335334
def __getitem__(self, item):
335+
336+
import jax.numpy as jnp
336337
result = self._array[item]
338+
337339
if isinstance(item, slice):
338340
result = self.with_new_array(result)
339341
if isinstance(result, jnp.ndarray):
@@ -342,6 +344,7 @@ def __getitem__(self, item):
342344

343345
def __setitem__(self, key, value):
344346
from jax import Array
347+
import jax.numpy as jnp
345348

346349
if isinstance(key, (jnp.ndarray, AbstractNDArray, Array)):
347350
self._array = jnp.where(key, value, self._array)

0 commit comments

Comments
 (0)