Skip to content

Commit b4f1853

Browse files
Jammy2211Jammy2211
authored andcommitted
simplify unit tests
1 parent 1094154 commit b4f1853

File tree

7 files changed

+142
-162
lines changed

7 files changed

+142
-162
lines changed

test_autoarray/structures/triangles/conftest.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from autoarray.numpy_wrapper import np
22
from autoarray.structures.triangles.array import ArrayTriangles
3+
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
34

45
from matplotlib import pyplot as plt
56

@@ -54,3 +55,18 @@ def triangles():
5455
]
5556
),
5657
)
58+
59+
@pytest.fixture
60+
def one_triangle():
61+
return CoordinateArrayTriangles(
62+
coordinates=np.array([[0, 0]]),
63+
side_length=1.0,
64+
)
65+
66+
67+
@pytest.fixture
68+
def two_triangles():
69+
return CoordinateArrayTriangles(
70+
coordinates=np.array([[0, 0], [1, 0]]),
71+
side_length=1.0,
72+
)

test_autoarray/structures/triangles/coordinate/__init__.py

Whitespace-only changes.

test_autoarray/structures/triangles/coordinate/conftest.py

Lines changed: 0 additions & 21 deletions
This file was deleted.

test_autoarray/structures/triangles/coordinate/test_coordinate_jax.py

Lines changed: 0 additions & 127 deletions
This file was deleted.

test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py renamed to test_autoarray/structures/triangles/test_coordinate.py

Lines changed: 121 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1-
import pytest
2-
1+
from jax import numpy as np
2+
import jax
33
import numpy as np
44

5+
jax.config.update("jax_log_compiles", True)
6+
7+
import pytest
8+
59
from autoarray.structures.triangles.abstract import HEIGHT_FACTOR
6-
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
710
from autoarray.structures.triangles.shape import Point
811

12+
from autoarray.structures.triangles.coordinate_array import (
13+
CoordinateArrayTriangles,
14+
)
15+
16+
917

1018
def test__two(two_triangles):
1119

@@ -253,3 +261,113 @@ def test_from_grid_regression():
253261
y = triangles.vertices[:, 1]
254262
assert min(y) <= -4.75
255263
assert max(y) >= 4.75
264+
265+
266+
@pytest.fixture
267+
def one_triangle():
268+
return CoordinateArrayTriangles(
269+
coordinates=np.array([[0, 0]]),
270+
side_length=1.0,
271+
)
272+
273+
274+
@jax.jit
275+
def full_routine(triangles):
276+
neighborhood = triangles.neighborhood()
277+
up_sampled = neighborhood.up_sample()
278+
with_vertices = up_sampled.with_vertices(up_sampled.vertices)
279+
indexes = with_vertices.containing_indices(Point(0.1, 0.1))
280+
return up_sampled.for_indexes(indexes)
281+
282+
283+
# def test_full_routine(one_triangle, compare_with_nans):
284+
# result = full_routine(one_triangle)
285+
#
286+
# assert compare_with_nans(
287+
# result.triangles,
288+
# np.array(
289+
# [
290+
# [
291+
# [0.0, 0.4330126941204071],
292+
# [0.25, 0.0],
293+
# [-0.25, 0.0],
294+
# ]
295+
# ]
296+
# ),
297+
# )
298+
299+
300+
def test_neighborhood(one_triangle):
301+
assert np.allclose(
302+
np.array(jax.jit(one_triangle.neighborhood)().triangles),
303+
np.array(
304+
[
305+
[
306+
[-0.5, -0.4330126941204071],
307+
[-1.0, 0.4330126941204071],
308+
[0.0, 0.4330126941204071],
309+
],
310+
[
311+
[0.0, -1.299038052558899],
312+
[-0.5, -0.4330126941204071],
313+
[0.5, -0.4330126941204071],
314+
],
315+
[
316+
[0.0, 0.4330126941204071],
317+
[0.5, -0.4330126941204071],
318+
[-0.5, -0.4330126941204071],
319+
],
320+
[
321+
[0.5, -0.4330126941204071],
322+
[0.0, 0.4330126941204071],
323+
[1.0, 0.4330126941204071],
324+
],
325+
]
326+
),
327+
)
328+
329+
330+
def test_up_sample(one_triangle):
331+
up_sampled = jax.jit(one_triangle.up_sample)()
332+
assert np.allclose(
333+
np.array(up_sampled.triangles),
334+
np.array(
335+
[
336+
[
337+
[[0.0, -0.4330126941204071], [-0.25, 0.0], [0.25, 0.0]],
338+
[
339+
[0.25, 0.0],
340+
[0.5, -0.4330126941204071],
341+
[0.0, -0.4330126941204071],
342+
],
343+
[
344+
[-0.25, 0.0],
345+
[0.0, -0.4330126941204071],
346+
[-0.5, -0.4330126941204071],
347+
],
348+
[[0.0, 0.4330126941204071], [0.25, 0.0], [-0.25, 0.0]],
349+
]
350+
]
351+
),
352+
)
353+
354+
355+
def test_means(one_triangle):
356+
assert len(one_triangle.means) == 1
357+
358+
up_sampled = one_triangle.up_sample()
359+
neighborhood = up_sampled.neighborhood()
360+
assert np.count_nonzero(~np.isnan(neighborhood.means).any(axis=1)) == 10
361+
362+
363+
ONE_TRIANGLE_AREA = HEIGHT_FACTOR * 0.5
364+
365+
366+
def test_area(one_triangle):
367+
assert one_triangle.area == ONE_TRIANGLE_AREA
368+
assert one_triangle.up_sample().area == ONE_TRIANGLE_AREA
369+
370+
neighborhood = one_triangle.neighborhood()
371+
assert neighborhood.area == 4 * ONE_TRIANGLE_AREA
372+
assert neighborhood.up_sample().area == 4 * ONE_TRIANGLE_AREA
373+
assert neighborhood.neighborhood().area == 10 * ONE_TRIANGLE_AREA

test_autoarray/structures/triangles/test_jax.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,13 @@
1-
from autoarray.structures.triangles.shape import Point
2-
3-
try:
4-
from jax import numpy as np
5-
import jax
1+
from jax import numpy as np
2+
import jax
63

7-
jax.config.update("jax_log_compiles", True)
8-
from autoarray.structures.triangles.array import ArrayTriangles
9-
except ImportError:
10-
import numpy as np
11-
from autoarray.structures.triangles.array import ArrayTriangles
4+
jax.config.update("jax_log_compiles", True)
125

136
import pytest
147

158

16-
pytest.importorskip("jax")
9+
from autoarray.structures.triangles.shape import Point
10+
from autoarray.structures.triangles.array import ArrayTriangles
1711

1812

1913
@pytest.fixture

test_autoarray/structures/triangles/coordinate/test_vertex_coordinates.py renamed to test_autoarray/structures/triangles/test_vertex_coordinates.py

File renamed without changes.

0 commit comments

Comments
 (0)