Skip to content

Commit a676b2b

Browse files
Jammy2211Jammy2211
authored andcommitted
jax wrapper moved up to autoconf
1 parent 766333b commit a676b2b

File tree

9 files changed

+7
-52
lines changed

9 files changed

+7
-52
lines changed

autoarray/abstract_ndarray.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from abc import ABC
66
from abc import abstractmethod
77
import jax.numpy as jnp
8+
from jax._src.tree_util import register_pytree_node
9+
from jax import Array
810

911
from autoconf.fitsable import output_to_fits
1012

11-
from autoarray.numpy_wrapper import register_pytree_node, Array
12-
1313
from typing import TYPE_CHECKING
1414

1515
if TYPE_CHECKING:

autoarray/mask/derive/indexes_2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from autoconf import cached_property
66

7-
from autoarray.numpy_wrapper import register_pytree_node_class
7+
from jax._src.tree_util import register_pytree_node_class
88
from typing import TYPE_CHECKING
99

1010
if TYPE_CHECKING:

autoarray/numpy_wrapper.py

Lines changed: 0 additions & 39 deletions
This file was deleted.

autoarray/operators/over_sampling/over_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import jax.numpy as jnp
33
import jax
4+
from jax._src.tree_util import register_pytree_node_class
45
from typing import Union
56

67
from autoconf import conf
@@ -11,7 +12,6 @@
1112

1213
from autoarray.operators.over_sampling import over_sample_util
1314

14-
from autoarray.numpy_wrapper import register_pytree_node_class
1515

1616

1717
@register_pytree_node_class

autoarray/structures/triangles/coordinate_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22

33
import numpy as np
44
import jax.numpy as jnp
5+
from jax._src.tree_util import register_pytree_node_class
56

67
from autoarray.structures.triangles.abstract import HEIGHT_FACTOR
78
from autoarray.structures.triangles.abstract import AbstractTriangles
89
from autoarray.structures.triangles.array import ArrayTriangles
9-
from autoarray.numpy_wrapper import register_pytree_node_class
1010

1111

1212
@register_pytree_node_class

autoarray/structures/triangles/shape.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from abc import ABC, abstractmethod
2+
from jax._src.tree_util import register_pytree_node_class
23
from typing import List, Tuple
34

45
import numpy as np
56

6-
from autoarray.numpy_wrapper import register_pytree_node_class
7-
87

98
class Shape(ABC):
109
"""

pyproject.toml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ dependencies = [
2929
"astropy>=5.0,<=6.1.2",
3030
"decorator>=4.0.0",
3131
"dill>=0.3.1.1",
32-
"jax==0.4.28",
33-
"jaxlib==0.4.28",
3432
"jaxnnls==1.0.1",
3533
"matplotlib>=3.7.0",
3634
"scipy<=1.14.0",

test_autoarray/conftest.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
1-
import jax
21
import jax.numpy as jnp
32

43
def pytest_configure():
54
_ = jnp.sum(jnp.array([0.0])) # Force backend init
65

7-
jax.config.update("jax_enable_x64", True)
8-
96
import os
107
from os import path
118
import pytest

test_autoarray/structures/triangles/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from autoarray.numpy_wrapper import np
1+
from autoconf.jax_wrapper import np
22
from autoarray.structures.triangles.array import ArrayTriangles
33
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
44

0 commit comments

Comments
 (0)