1- from jax import numpy as np
21import jax
3-
4- jax .config . update ( "jax_log_compiles" , True )
2+ import jax . numpy as jnp
3+ from jax .tree_util import register_pytree_node_class
54
65import pytest
76
8-
97from autoarray .structures .triangles .shape import Point
108from autoarray .structures .triangles .array import ArrayTriangles
119
10+ ArrayTriangles = register_pytree_node_class (ArrayTriangles )
11+ Point = register_pytree_node_class (Point )
1212
1313@pytest .fixture
1414def triangles ():
1515 return ArrayTriangles (
16- indices = np .array (
16+ indices = jnp .array (
1717 [
1818 [0 , 1 , 2 ],
1919 [1 , 2 , 3 ],
2020 ]
2121 ),
22- vertices = np .array (
22+ vertices = jnp .array (
2323 [
2424 [0.0 , 0.0 ],
2525 [1.0 , 0.0 ],
@@ -36,37 +36,37 @@ def triangles():
3636 [
3737 (
3838 Point (0.1 , 0.1 ),
39- np .array (
39+ jnp .array (
4040 [
4141 [0.0 , 0.0 ],
4242 [0.0 , 1.0 ],
4343 [1.0 , 0.0 ],
4444 ]
4545 ),
46- np .array ([0 , - 1 , - 1 , - 1 , - 1 ]),
46+ jnp .array ([0 , - 1 , - 1 , - 1 , - 1 ]),
4747 ),
4848 (
4949 Point (0.6 , 0.6 ),
50- np .array (
50+ jnp .array (
5151 [
5252 [0.0 , 1.0 ],
5353 [1.0 , 0.0 ],
5454 [1.0 , 1.0 ],
5555 ]
5656 ),
57- np .array ([1 , - 1 , - 1 , - 1 , - 1 ]),
57+ jnp .array ([1 , - 1 , - 1 , - 1 , - 1 ]),
5858 ),
5959 (
6060 Point (0.5 , 0.5 ),
61- np .array (
61+ jnp .array (
6262 [
6363 [0.0 , 0.0 ],
6464 [0.0 , 1.0 ],
6565 [1.0 , 0.0 ],
6666 [1.0 , 1.0 ],
6767 ]
6868 ),
69- np .array ([0 , 1 , - 1 , - 1 , - 1 ]),
69+ jnp .array ([0 , 1 , - 1 , - 1 , - 1 ]),
7070 ),
7171 ],
7272)
@@ -85,46 +85,46 @@ def test_contains_vertices(
8585 "indexes, vertices, indices" ,
8686 [
8787 (
88- np .array ([0 ]),
89- np .array (
88+ jnp .array ([0 ]),
89+ jnp .array (
9090 [
9191 [0.0 , 0.0 ],
9292 [0.0 , 1.0 ],
9393 [1.0 , 0.0 ],
9494 ]
9595 ),
96- np .array (
96+ jnp .array (
9797 [
9898 [0 , 1 , 2 ],
9999 ]
100100 ),
101101 ),
102102 (
103- np .array ([1 ]),
104- np .array (
103+ jnp .array ([1 ]),
104+ jnp .array (
105105 [
106106 [0.0 , 1.0 ],
107107 [1.0 , 0.0 ],
108108 [1.0 , 1.0 ],
109109 ]
110110 ),
111- np .array (
111+ jnp .array (
112112 [
113113 [0 , 1 , 2 ],
114114 ]
115115 ),
116116 ),
117117 (
118- np .array ([0 , 1 ]),
119- np .array (
118+ jnp .array ([0 , 1 ]),
119+ jnp .array (
120120 [
121121 [0.0 , 0.0 ],
122122 [0.0 , 1.0 ],
123123 [1.0 , 0.0 ],
124124 [1.0 , 1.0 ],
125125 ],
126126 ),
127- np .array (
127+ jnp .array (
128128 [
129129 [0 , 1 , 2 ],
130130 [1 , 2 , 3 ],
@@ -153,13 +153,13 @@ def test_negative_index(
153153 triangles ,
154154 compare_with_nans ,
155155):
156- indexes = np .array ([0 , - 1 ])
156+ indexes = jnp .array ([0 , - 1 ])
157157
158158 containing = jax .jit (triangles .for_indexes )(indexes )
159159
160160 assert (
161161 containing .indices
162- == np .array (
162+ == jnp .array (
163163 [
164164 [- 1 , - 1 , - 1 ],
165165 [0 , 1 , 2 ],
@@ -168,7 +168,7 @@ def test_negative_index(
168168 ).all ()
169169 assert compare_with_nans (
170170 containing .vertices ,
171- np .array (
171+ jnp .array (
172172 [
173173 [0.0 , 0.0 ],
174174 [0.0 , 1.0 ],
@@ -186,7 +186,7 @@ def test_up_sample(
186186
187187 assert compare_with_nans (
188188 up_sampled .vertices ,
189- np .array (
189+ jnp .array (
190190 [
191191 [0.0 , 0.0 ],
192192 [0.0 , 0.5 ],
@@ -203,7 +203,7 @@ def test_up_sample(
203203
204204 assert (
205205 up_sampled .indices
206- == np .array (
206+ == jnp .array (
207207 [
208208 [0 , 1 , 3 ],
209209 [1 , 2 , 4 ],
@@ -224,12 +224,12 @@ def test_up_sample(
224224)
225225def test_simple_neighborhood (offset , compare_with_nans ):
226226 triangles = ArrayTriangles (
227- indices = np .array (
227+ indices = jnp .array (
228228 [
229229 [0 , 1 , 2 ],
230230 ]
231231 ),
232- vertices = np .array (
232+ vertices = jnp .array (
233233 [
234234 [0.0 , 0.0 ],
235235 [1.0 , 0.0 ],
@@ -242,7 +242,7 @@ def test_simple_neighborhood(offset, compare_with_nans):
242242 assert compare_with_nans (
243243 jax .jit (triangles .neighborhood )().triangles ,
244244 (
245- np .array (
245+ jnp .array (
246246 [
247247 [[- 1.0 , 1.0 ], [0.0 , 0.0 ], [0.0 , 1.0 ]],
248248 [[0.0 , 0.0 ], [0.0 , 1.0 ], [1.0 , 0.0 ]],
@@ -260,7 +260,7 @@ def test_neighborhood(triangles, compare_with_nans):
260260
261261 assert compare_with_nans (
262262 neighborhood .vertices ,
263- np .array (
263+ jnp .array (
264264 [
265265 [- 1.0 , 1.0 ],
266266 [0.0 , 0.0 ],
@@ -276,7 +276,7 @@ def test_neighborhood(triangles, compare_with_nans):
276276
277277 assert (
278278 neighborhood .indices
279- == np .array (
279+ == jnp .array (
280280 [
281281 [0 , 1 , 2 ],
282282 [1 , 2 , 5 ],
@@ -294,7 +294,7 @@ def test_neighborhood(triangles, compare_with_nans):
294294def test_means (triangles ):
295295 means = triangles .means
296296 assert means == pytest .approx (
297- np .array (
297+ jnp .array (
298298 [
299299 [0.33333333 , 0.33333333 ],
300300 [0.66666667 , 0.66666667 ],
0 commit comments