44
55from abc import ABC
66from abc import abstractmethod
7- import jax . numpy as jnp
7+
88from jax ._src .tree_util import register_pytree_node
99
1010import numpy as np
@@ -88,7 +88,7 @@ def __init__(self, array, xp=np):
8888
8989 def invert (self ):
9090 new = self .copy ()
91- new ._array = jnp .invert (new ._array )
91+ new ._array = self . _xp .invert (new ._array )
9292 return new
9393
9494 @classmethod
@@ -117,7 +117,7 @@ def instance_unflatten(cls, aux_data, children):
117117 setattr (instance , key , value )
118118 return instance
119119
120- def with_new_array (self , array : jnp .ndarray ) -> "AbstractNDArray" :
120+ def with_new_array (self , array : np .ndarray ) -> "AbstractNDArray" :
121121 """
122122 Copy this object but give it a new array.
123123
@@ -137,10 +137,9 @@ def with_new_array(self, array: jnp.ndarray) -> "AbstractNDArray":
137137 new_array ._array = array
138138 return new_array
139139
140- @staticmethod
141- def flip_hdu_for_ds9 (values ):
140+ def flip_hdu_for_ds9 (self , values ):
142141 if conf .instance ["general" ]["fits" ]["flip_for_ds9" ]:
143- return jnp .flipud (values )
142+ return self . _xp .flipud (values )
144143 return values
145144
146145 def copy (self ):
@@ -170,7 +169,7 @@ def __iter__(self):
170169
171170 @to_new_array
172171 def sqrt (self ):
173- return jnp .sqrt (self ._array )
172+ return self . _xp .sqrt (self ._array )
174173
175174 @property
176175 def array (self ):
@@ -333,7 +332,10 @@ def __getattr__(self, item):
333332 )
334333
335334 def __getitem__ (self , item ):
335+
336+ import jax .numpy as jnp
336337 result = self ._array [item ]
338+
337339 if isinstance (item , slice ):
338340 result = self .with_new_array (result )
339341 if isinstance (result , jnp .ndarray ):
@@ -342,6 +344,7 @@ def __getitem__(self, item):
342344
343345 def __setitem__ (self , key , value ):
344346 from jax import Array
347+ import jax .numpy as jnp
345348
346349 if isinstance (key , (jnp .ndarray , AbstractNDArray , Array )):
347350 self ._array = jnp .where (key , value , self ._array )
0 commit comments