Skip to content

Commit e27efb5

Browse files
authored
Merge pull request #252 from PyAutoLabs/feature/unit-test-profiling
perf: speed up unit tests 63% by removing JAX from triangle tests
2 parents 2448ebf + 44b8145 commit e27efb5

File tree

8 files changed

+40
-493
lines changed

8 files changed

+40
-493
lines changed

test_autoarray/inversion/inversion/imaging/test_inversion_imaging_util.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,13 +260,13 @@ def test__data_vector_via_weighted_data_two_methods_agree():
260260

261261
def test__curvature_matrix_via_psf_weighted_noise_two_methods_agree():
262262

263-
mask = aa.Mask2D.circular(shape_native=(51, 51), pixel_scales=0.1, radius=2.0)
263+
mask = aa.Mask2D.circular(shape_native=(21, 21), pixel_scales=0.1, radius=0.8)
264264

265265
noise_map = np.random.uniform(size=mask.shape_native)
266266
noise_map = aa.Array2D(values=noise_map, mask=mask)
267267

268268
kernel = aa.Convolver.from_gaussian(
269-
shape_native=(7, 7), pixel_scales=mask.pixel_scales, sigma=1.0, normalize=True
269+
shape_native=(5, 5), pixel_scales=mask.pixel_scales, sigma=1.0, normalize=True
270270
)
271271

272272
psf = kernel
@@ -277,7 +277,7 @@ def test__curvature_matrix_via_psf_weighted_noise_two_methods_agree():
277277
psf=psf.kernel.native,
278278
)
279279

280-
mesh = aa.mesh.RectangularAdaptDensity(shape=(20, 20))
280+
mesh = aa.mesh.RectangularAdaptDensity(shape=(8, 8))
281281

282282
interpolator = mesh.interpolator_from(
283283
source_plane_data_grid=mask.derive_grid.unmasked,
Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
import jax.numpy as jnp
1+
import numpy as np
22
from matplotlib import pyplot as plt
33
import pytest
44

5-
from autoarray.structures.triangles.array import ArrayTriangles
6-
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
5+
from autoarray.structures.triangles.array_np import ArrayTrianglesNp
6+
from autoarray.structures.triangles.coordinate_array_np import CoordinateArrayTrianglesNp
77

88

99
@pytest.fixture
@@ -12,8 +12,8 @@ def plot():
1212

1313
def plot(triangles, color="black"):
1414
for triangle in triangles:
15-
triangle = jnp.array(triangle)
16-
triangle = jnp.append(triangle, jnp.array([triangle[0]]), axis=0)
15+
triangle = np.array(triangle)
16+
triangle = np.append(triangle, np.array([triangle[0]]), axis=0)
1717
plt.plot(triangle[:, 0], triangle[:, 1], "o-", color=color)
1818

1919
yield plot
@@ -24,27 +24,27 @@ def plot(triangles, color="black"):
2424
@pytest.fixture
2525
def compare_with_nans():
2626
def compare_with_nans_(arr1, arr2):
27-
nan_mask1 = jnp.isnan(arr1)
28-
nan_mask2 = jnp.isnan(arr2)
27+
nan_mask1 = np.isnan(arr1)
28+
nan_mask2 = np.isnan(arr2)
2929

3030
arr1 = arr1[~nan_mask1]
3131
arr2 = arr2[~nan_mask2]
3232

33-
return jnp.all(arr1 == arr2)
33+
return np.all(arr1 == arr2)
3434

3535
return compare_with_nans_
3636

3737

3838
@pytest.fixture
3939
def triangles():
40-
return ArrayTriangles(
41-
indices=jnp.array(
40+
return ArrayTrianglesNp(
41+
indices=np.array(
4242
[
4343
[0, 1, 2],
4444
[1, 2, 3],
4545
]
4646
),
47-
vertices=jnp.array(
47+
vertices=np.array(
4848
[
4949
[0.0, 0.0],
5050
[1.0, 0.0],
@@ -57,15 +57,15 @@ def triangles():
5757

5858
@pytest.fixture
5959
def one_triangle():
60-
return CoordinateArrayTriangles(
61-
coordinates=jnp.array([[0, 0]]),
60+
return CoordinateArrayTrianglesNp(
61+
coordinates=np.array([[0, 0]]),
6262
side_length=1.0,
6363
)
6464

6565

6666
@pytest.fixture
6767
def two_triangles():
68-
return CoordinateArrayTriangles(
69-
coordinates=jnp.array([[0, 0], [1, 0]]),
68+
return CoordinateArrayTrianglesNp(
69+
coordinates=np.array([[0, 0], [1, 0]]),
7070
side_length=1.0,
7171
)

test_autoarray/structures/triangles/test_area.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import numpy as np
22
import pytest
33

4-
from autoarray.structures.triangles.array import ArrayTriangles
4+
from autoarray.structures.triangles.array_np import ArrayTrianglesNp as ArrayTriangles
55
from autoarray.structures.triangles.shape import Triangle, Circle, Square, Polygon
66

77

test_autoarray/structures/triangles/test_coordinate.py

Lines changed: 17 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
import numpy as np
2-
from jax.tree_util import register_pytree_node_class
32

43
import pytest
54

65
from autoarray.structures.triangles.abstract import HEIGHT_FACTOR
76

8-
from autoarray.structures.triangles.coordinate_array import (
9-
CoordinateArrayTriangles,
7+
from autoarray.structures.triangles.coordinate_array_np import (
8+
CoordinateArrayTrianglesNp,
109
)
1110

12-
CoordinateArrayTriangles = register_pytree_node_class(CoordinateArrayTriangles)
13-
1411

1512
def test__two(two_triangles):
1613

@@ -52,7 +49,7 @@ def test__trivial_triangles(one_triangle):
5249

5350

5451
def test__above():
55-
triangles = CoordinateArrayTriangles(
52+
triangles = CoordinateArrayTrianglesNp(
5653
coordinates=np.array([[0, 1]]),
5754
side_length=1.0,
5855
)
@@ -87,7 +84,7 @@ def test__above():
8784

8885
@pytest.fixture
8986
def upside_down():
90-
return CoordinateArrayTriangles(
87+
return CoordinateArrayTrianglesNp(
9188
coordinates=np.array([[1, 0]]),
9289
side_length=1.0,
9390
)
@@ -279,104 +276,37 @@ def test_means(one_triangle):
279276

280277

281278
def test_triangles_touch():
282-
triangles = CoordinateArrayTriangles(
279+
triangles = CoordinateArrayTrianglesNp(
283280
np.array([[0, 0], [2, 0]]),
284281
)
285282

286283
assert max(triangles.triangles[0][:, 0]) == min(triangles.triangles[1][:, 0])
287284

288-
triangles = CoordinateArrayTriangles(
285+
triangles = CoordinateArrayTrianglesNp(
289286
np.array([[0, 0], [0, 1]]),
290287
)
291288
assert max(triangles.triangles[0][:, 1]) == min(triangles.triangles[1][:, 1])
292289

293290

294291
def test_from_grid_regression():
295-
triangles = CoordinateArrayTriangles.for_limits_and_scale(
296-
x_min=-4.75,
297-
x_max=4.75,
298-
y_min=-4.75,
299-
y_max=4.75,
300-
scale=0.5,
292+
triangles = CoordinateArrayTrianglesNp.for_limits_and_scale(
293+
x_min=-2.0,
294+
x_max=2.0,
295+
y_min=-2.0,
296+
y_max=2.0,
297+
scale=1.5,
301298
)
302299

303300
x = triangles.vertices[:, 0]
304-
assert min(x) <= -4.75
305-
assert max(x) >= 4.75
301+
assert min(x) <= -2.0
302+
assert max(x) >= 2.0
306303

307304
y = triangles.vertices[:, 1]
308-
assert min(y) <= -4.75
309-
assert max(y) >= 4.75
310-
305+
assert min(y) <= -2.0
306+
assert max(y) >= 2.0
311307

312-
@pytest.fixture
313-
def one_triangle():
314-
return CoordinateArrayTriangles(
315-
coordinates=np.array([[0, 0]]),
316-
side_length=1.0,
317-
)
318308

319-
320-
def test_neighborhood(one_triangle):
321-
import jax
322-
323-
assert np.allclose(
324-
np.array(jax.jit(one_triangle.neighborhood)().triangles),
325-
np.array(
326-
[
327-
[
328-
[-0.5, -0.4330126941204071],
329-
[-1.0, 0.4330126941204071],
330-
[0.0, 0.4330126941204071],
331-
],
332-
[
333-
[0.0, -1.299038052558899],
334-
[-0.5, -0.4330126941204071],
335-
[0.5, -0.4330126941204071],
336-
],
337-
[
338-
[0.0, 0.4330126941204071],
339-
[0.5, -0.4330126941204071],
340-
[-0.5, -0.4330126941204071],
341-
],
342-
[
343-
[0.5, -0.4330126941204071],
344-
[0.0, 0.4330126941204071],
345-
[1.0, 0.4330126941204071],
346-
],
347-
]
348-
),
349-
)
350-
351-
352-
def test_up_sample(one_triangle):
353-
import jax
354-
355-
up_sampled = jax.jit(one_triangle.up_sample)()
356-
assert np.allclose(
357-
np.array(up_sampled.triangles),
358-
np.array(
359-
[
360-
[
361-
[[0.0, -0.4330126941204071], [-0.25, 0.0], [0.25, 0.0]],
362-
[
363-
[0.25, 0.0],
364-
[0.5, -0.4330126941204071],
365-
[0.0, -0.4330126941204071],
366-
],
367-
[
368-
[-0.25, 0.0],
369-
[0.0, -0.4330126941204071],
370-
[-0.5, -0.4330126941204071],
371-
],
372-
[[0.0, 0.4330126941204071], [0.25, 0.0], [-0.25, 0.0]],
373-
]
374-
]
375-
),
376-
)
377-
378-
379-
def test_means(one_triangle):
309+
def test_means_up_sampled(one_triangle):
380310
assert len(one_triangle.means) == 1
381311

382312
up_sampled = one_triangle.up_sample()

test_autoarray/structures/triangles/test_extended_source.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
import pytest
22
import numpy as np
33

4-
from autoarray.structures.triangles.array import ArrayTriangles
4+
from autoarray.structures.triangles.array_np import ArrayTrianglesNp
55
from autoarray.structures.triangles.shape import Circle
66

77

88
@pytest.fixture
99
def triangles():
10-
return ArrayTriangles(
10+
return ArrayTrianglesNp(
1111
indices=np.array(
1212
[
1313
[0, 1, 2],
@@ -49,7 +49,7 @@ def test_small_point(triangles, point, indices):
4949
radius=0.001,
5050
)
5151
)
52-
assert [i for i in containing_triangles.tolist() if i != -1] == indices
52+
assert containing_triangles.tolist() == indices
5353

5454

5555
@pytest.mark.parametrize(
@@ -72,4 +72,4 @@ def test_large_circle(
7272
radius=radius,
7373
)
7474
)
75-
assert [i for i in containing_triangles.tolist() if i != -1] == indices
75+
assert containing_triangles.tolist() == indices

0 commit comments

Comments
 (0)