1- from autoconf .jax_wrapper import np
2- from autoarray .structures .triangles .array import ArrayTriangles
3- from autoarray .structures .triangles .coordinate_array import CoordinateArrayTriangles
4-
1+ import jax .numpy as jnp
52from matplotlib import pyplot as plt
6-
7-
83import pytest
94
5+ from autoarray .structures .triangles .array import ArrayTriangles
6+ from autoarray .structures .triangles .coordinate_array import CoordinateArrayTriangles
7+
108
119@pytest .fixture
1210def plot ():
1311 plt .figure (figsize = (8 , 8 ))
1412
1513 def plot (triangles , color = "black" ):
1614 for triangle in triangles :
17- triangle = np .array (triangle )
18- triangle = np .append (triangle , np .array ([triangle [0 ]]), axis = 0 )
15+ triangle = jnp .array (triangle )
16+ triangle = jnp .append (triangle , jnp .array ([triangle [0 ]]), axis = 0 )
1917 plt .plot (triangle [:, 0 ], triangle [:, 1 ], "o-" , color = color )
2018
2119 yield plot
@@ -26,27 +24,27 @@ def plot(triangles, color="black"):
2624@pytest .fixture
2725def compare_with_nans ():
2826 def compare_with_nans_ (arr1 , arr2 ):
29- nan_mask1 = np .isnan (arr1 )
30- nan_mask2 = np .isnan (arr2 )
27+ nan_mask1 = jnp .isnan (arr1 )
28+ nan_mask2 = jnp .isnan (arr2 )
3129
3230 arr1 = arr1 [~ nan_mask1 ]
3331 arr2 = arr2 [~ nan_mask2 ]
3432
35- return np .all (arr1 == arr2 )
33+ return jnp .all (arr1 == arr2 )
3634
3735 return compare_with_nans_
3836
3937
4038@pytest .fixture
4139def triangles ():
4240 return ArrayTriangles (
43- indices = np .array (
41+ indices = jnp .array (
4442 [
4543 [0 , 1 , 2 ],
4644 [1 , 2 , 3 ],
4745 ]
4846 ),
49- vertices = np .array (
47+ vertices = jnp .array (
5048 [
5149 [0.0 , 0.0 ],
5250 [1.0 , 0.0 ],
@@ -60,14 +58,14 @@ def triangles():
6058@pytest .fixture
6159def one_triangle ():
6260 return CoordinateArrayTriangles (
63- coordinates = np .array ([[0 , 0 ]]),
61+ coordinates = jnp .array ([[0 , 0 ]]),
6462 side_length = 1.0 ,
6563 )
6664
6765
6866@pytest .fixture
6967def two_triangles ():
7068 return CoordinateArrayTriangles (
71- coordinates = np .array ([[0 , 0 ], [1 , 0 ]]),
69+ coordinates = jnp .array ([[0 , 0 ], [1 , 0 ]]),
7270 side_length = 1.0 ,
7371 )
0 commit comments