Skip to content

Commit 2d2fe36

Browse files
Jammy2211Jammy2211
authored andcommitted
triangles
1 parent 6224b92 commit 2d2fe36

File tree

7 files changed

+40
-56
lines changed

7 files changed

+40
-56
lines changed

autoarray/abstract_ndarray.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from abc import ABC
66
from abc import abstractmethod
77

8-
from jax._src.tree_util import register_pytree_node
9-
108
import numpy as np
119

1210
from autoconf.fitsable import output_to_fits
@@ -75,14 +73,14 @@ def __init__(self, array, xp=np):
7573
while isinstance(array, AbstractNDArray):
7674
array = array.array
7775
self._array = array
78-
try:
79-
register_pytree_node(
80-
type(self),
81-
self.instance_flatten,
82-
self.instance_unflatten,
83-
)
84-
except ValueError:
85-
pass
76+
# try:
77+
# register_pytree_node(
78+
# type(self),
79+
# self.instance_flatten,
80+
# self.instance_unflatten,
81+
# )
82+
# except ValueError:
83+
# pass
8684

8785
self._xp = xp
8886

autoarray/mask/derive/indexes_2d.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import logging
33
import numpy as np
44

5-
from jax._src.tree_util import register_pytree_node_class
65
from typing import TYPE_CHECKING
76

87
if TYPE_CHECKING:
@@ -14,7 +13,6 @@
1413
logger = logging.getLogger(__name__)
1514

1615

17-
@register_pytree_node_class
1816
class DeriveIndexes2D:
1917

2018
def __init__(self, mask: Mask2D, xp=np):

autoarray/operators/over_sampling/over_sampler.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22

3-
from jax._src.tree_util import register_pytree_node_class
43
from typing import Union
54

65
from autoconf import conf
@@ -11,7 +10,6 @@
1110
from autoarray.operators.over_sampling import over_sample_util
1211

1312

14-
@register_pytree_node_class
1513
class OverSampler:
1614
def __init__(self, mask: Mask2D, sub_size: Union[int, Array2D]):
1715
"""

autoarray/structures/triangles/array.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
11
import numpy as np
22

3-
from jax.tree_util import register_pytree_node_class
4-
53
from autoarray.structures.triangles.abstract import HEIGHT_FACTOR
64

75
from autoarray.structures.triangles.abstract import AbstractTriangles
86
from autoarray.structures.triangles.shape import Shape
97

108
MAX_CONTAINING_SIZE = 15
119

12-
13-
@register_pytree_node_class
1410
class ArrayTriangles(AbstractTriangles):
1511
def __init__(
1612
self,

autoarray/structures/triangles/coordinate_array.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,11 @@
22

33
import numpy as np
44

5-
from jax._src.tree_util import register_pytree_node_class
6-
75
from autoarray.structures.triangles.abstract import HEIGHT_FACTOR
86
from autoarray.structures.triangles.abstract import AbstractTriangles
97
from autoarray.structures.triangles.array import ArrayTriangles
108

119

12-
@register_pytree_node_class
1310
class CoordinateArrayTriangles(AbstractTriangles, ABC):
1411

1512
def __init__(

autoarray/structures/triangles/shape.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from abc import ABC, abstractmethod
2-
from jax._src.tree_util import register_pytree_node_class
32
from typing import List, Tuple
43

54
import numpy as np
@@ -34,7 +33,6 @@ def mask(self, triangles: np.ndarray) -> np.ndarray:
3433
"""
3534

3635

37-
@register_pytree_node_class
3836
class Point(Shape):
3937
def __init__(self, x: float, y: float):
4038
"""
@@ -107,7 +105,6 @@ def centroid(triangles: np.ndarray):
107105
return (x1 + x2 + x3) / 3, (y1 + y2 + y3) / 3
108106

109107

110-
@register_pytree_node_class
111108
class Circle(Point):
112109
def __init__(
113110
self,

test_autoarray/structures/triangles/test_jax.py

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,25 @@
1-
from jax import numpy as np
21
import jax
3-
4-
jax.config.update("jax_log_compiles", True)
2+
import jax.numpy as jnp
3+
from jax.tree_util import register_pytree_node_class
54

65
import pytest
76

8-
97
from autoarray.structures.triangles.shape import Point
108
from autoarray.structures.triangles.array import ArrayTriangles
119

10+
ArrayTriangles = register_pytree_node_class(ArrayTriangles)
11+
Point = register_pytree_node_class(Point)
1212

1313
@pytest.fixture
1414
def triangles():
1515
return ArrayTriangles(
16-
indices=np.array(
16+
indices=jnp.array(
1717
[
1818
[0, 1, 2],
1919
[1, 2, 3],
2020
]
2121
),
22-
vertices=np.array(
22+
vertices=jnp.array(
2323
[
2424
[0.0, 0.0],
2525
[1.0, 0.0],
@@ -36,37 +36,37 @@ def triangles():
3636
[
3737
(
3838
Point(0.1, 0.1),
39-
np.array(
39+
jnp.array(
4040
[
4141
[0.0, 0.0],
4242
[0.0, 1.0],
4343
[1.0, 0.0],
4444
]
4545
),
46-
np.array([0, -1, -1, -1, -1]),
46+
jnp.array([0, -1, -1, -1, -1]),
4747
),
4848
(
4949
Point(0.6, 0.6),
50-
np.array(
50+
jnp.array(
5151
[
5252
[0.0, 1.0],
5353
[1.0, 0.0],
5454
[1.0, 1.0],
5555
]
5656
),
57-
np.array([1, -1, -1, -1, -1]),
57+
jnp.array([1, -1, -1, -1, -1]),
5858
),
5959
(
6060
Point(0.5, 0.5),
61-
np.array(
61+
jnp.array(
6262
[
6363
[0.0, 0.0],
6464
[0.0, 1.0],
6565
[1.0, 0.0],
6666
[1.0, 1.0],
6767
]
6868
),
69-
np.array([0, 1, -1, -1, -1]),
69+
jnp.array([0, 1, -1, -1, -1]),
7070
),
7171
],
7272
)
@@ -85,46 +85,46 @@ def test_contains_vertices(
8585
"indexes, vertices, indices",
8686
[
8787
(
88-
np.array([0]),
89-
np.array(
88+
jnp.array([0]),
89+
jnp.array(
9090
[
9191
[0.0, 0.0],
9292
[0.0, 1.0],
9393
[1.0, 0.0],
9494
]
9595
),
96-
np.array(
96+
jnp.array(
9797
[
9898
[0, 1, 2],
9999
]
100100
),
101101
),
102102
(
103-
np.array([1]),
104-
np.array(
103+
jnp.array([1]),
104+
jnp.array(
105105
[
106106
[0.0, 1.0],
107107
[1.0, 0.0],
108108
[1.0, 1.0],
109109
]
110110
),
111-
np.array(
111+
jnp.array(
112112
[
113113
[0, 1, 2],
114114
]
115115
),
116116
),
117117
(
118-
np.array([0, 1]),
119-
np.array(
118+
jnp.array([0, 1]),
119+
jnp.array(
120120
[
121121
[0.0, 0.0],
122122
[0.0, 1.0],
123123
[1.0, 0.0],
124124
[1.0, 1.0],
125125
],
126126
),
127-
np.array(
127+
jnp.array(
128128
[
129129
[0, 1, 2],
130130
[1, 2, 3],
@@ -153,13 +153,13 @@ def test_negative_index(
153153
triangles,
154154
compare_with_nans,
155155
):
156-
indexes = np.array([0, -1])
156+
indexes = jnp.array([0, -1])
157157

158158
containing = jax.jit(triangles.for_indexes)(indexes)
159159

160160
assert (
161161
containing.indices
162-
== np.array(
162+
== jnp.array(
163163
[
164164
[-1, -1, -1],
165165
[0, 1, 2],
@@ -168,7 +168,7 @@ def test_negative_index(
168168
).all()
169169
assert compare_with_nans(
170170
containing.vertices,
171-
np.array(
171+
jnp.array(
172172
[
173173
[0.0, 0.0],
174174
[0.0, 1.0],
@@ -186,7 +186,7 @@ def test_up_sample(
186186

187187
assert compare_with_nans(
188188
up_sampled.vertices,
189-
np.array(
189+
jnp.array(
190190
[
191191
[0.0, 0.0],
192192
[0.0, 0.5],
@@ -203,7 +203,7 @@ def test_up_sample(
203203

204204
assert (
205205
up_sampled.indices
206-
== np.array(
206+
== jnp.array(
207207
[
208208
[0, 1, 3],
209209
[1, 2, 4],
@@ -224,12 +224,12 @@ def test_up_sample(
224224
)
225225
def test_simple_neighborhood(offset, compare_with_nans):
226226
triangles = ArrayTriangles(
227-
indices=np.array(
227+
indices=jnp.array(
228228
[
229229
[0, 1, 2],
230230
]
231231
),
232-
vertices=np.array(
232+
vertices=jnp.array(
233233
[
234234
[0.0, 0.0],
235235
[1.0, 0.0],
@@ -242,7 +242,7 @@ def test_simple_neighborhood(offset, compare_with_nans):
242242
assert compare_with_nans(
243243
jax.jit(triangles.neighborhood)().triangles,
244244
(
245-
np.array(
245+
jnp.array(
246246
[
247247
[[-1.0, 1.0], [0.0, 0.0], [0.0, 1.0]],
248248
[[0.0, 0.0], [0.0, 1.0], [1.0, 0.0]],
@@ -260,7 +260,7 @@ def test_neighborhood(triangles, compare_with_nans):
260260

261261
assert compare_with_nans(
262262
neighborhood.vertices,
263-
np.array(
263+
jnp.array(
264264
[
265265
[-1.0, 1.0],
266266
[0.0, 0.0],
@@ -276,7 +276,7 @@ def test_neighborhood(triangles, compare_with_nans):
276276

277277
assert (
278278
neighborhood.indices
279-
== np.array(
279+
== jnp.array(
280280
[
281281
[0, 1, 2],
282282
[1, 2, 5],
@@ -294,7 +294,7 @@ def test_neighborhood(triangles, compare_with_nans):
294294
def test_means(triangles):
295295
means = triangles.means
296296
assert means == pytest.approx(
297-
np.array(
297+
jnp.array(
298298
[
299299
[0.33333333, 0.33333333],
300300
[0.66666667, 0.66666667],

0 commit comments

Comments
 (0)