File tree Expand file tree Collapse file tree 9 files changed +7
-52
lines changed
Expand file tree Collapse file tree 9 files changed +7
-52
lines changed Original file line number Diff line number Diff line change 55from abc import ABC
66from abc import abstractmethod
77import jax .numpy as jnp
8+ from jax ._src .tree_util import register_pytree_node
9+ from jax import Array
810
911from autoconf .fitsable import output_to_fits
1012
11- from autoarray .numpy_wrapper import register_pytree_node , Array
12-
1313from typing import TYPE_CHECKING
1414
1515if TYPE_CHECKING :
Original file line number Diff line number Diff line change 44
55from autoconf import cached_property
66
7- from autoarray . numpy_wrapper import register_pytree_node_class
7+ from jax . _src . tree_util import register_pytree_node_class
88from typing import TYPE_CHECKING
99
1010if TYPE_CHECKING :
Load Diff This file was deleted.
Original file line number Diff line number Diff line change 11import numpy as np
22import jax .numpy as jnp
33import jax
4+ from jax ._src .tree_util import register_pytree_node_class
45from typing import Union
56
67from autoconf import conf
1112
1213from autoarray .operators .over_sampling import over_sample_util
1314
14- from autoarray .numpy_wrapper import register_pytree_node_class
1515
1616
1717@register_pytree_node_class
Original file line number Diff line number Diff line change 22
33import numpy as np
44import jax .numpy as jnp
5+ from jax ._src .tree_util import register_pytree_node_class
56
67from autoarray .structures .triangles .abstract import HEIGHT_FACTOR
78from autoarray .structures .triangles .abstract import AbstractTriangles
89from autoarray .structures .triangles .array import ArrayTriangles
9- from autoarray .numpy_wrapper import register_pytree_node_class
1010
1111
1212@register_pytree_node_class
Original file line number Diff line number Diff line change 11from abc import ABC , abstractmethod
2+ from jax ._src .tree_util import register_pytree_node_class
23from typing import List , Tuple
34
45import numpy as np
56
6- from autoarray .numpy_wrapper import register_pytree_node_class
7-
87
98class Shape (ABC ):
109 """
Original file line number Diff line number Diff line change @@ -29,8 +29,6 @@ dependencies = [
2929 " astropy>=5.0,<=6.1.2" ,
3030 " decorator>=4.0.0" ,
3131 " dill>=0.3.1.1" ,
32- " jax==0.4.28" ,
33- " jaxlib==0.4.28" ,
3432 " jaxnnls==1.0.1" ,
3533 " matplotlib>=3.7.0" ,
3634 " scipy<=1.14.0" ,
Original file line number Diff line number Diff line change 1- import jax
21import jax .numpy as jnp
32
43def pytest_configure ():
54 _ = jnp .sum (jnp .array ([0.0 ])) # Force backend init
65
7- jax .config .update ("jax_enable_x64" , True )
8-
96import os
107from os import path
118import pytest
Original file line number Diff line number Diff line change 1- from autoarray . numpy_wrapper import np
1+ from autoconf . jax_wrapper import np
22from autoarray .structures .triangles .array import ArrayTriangles
33from autoarray .structures .triangles .coordinate_array import CoordinateArrayTriangles
44
You can’t perform that action at this time.
0 commit comments