Skip to content

Commit 0c31531

Browse files
Jammy2211Jammy2211
authored andcommitted
black
1 parent 2d2fe36 commit 0c31531

File tree

7 files changed

+20
-18
lines changed

7 files changed

+20
-18
lines changed

autoarray/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from autoconf import jax_wrapper
12
from autoconf.dictable import register_parser
23
from autoconf import conf
34

autoarray/abstract_ndarray.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ def __getattr__(self, item):
332332
def __getitem__(self, item):
333333

334334
import jax.numpy as jnp
335+
335336
result = self._array[item]
336337

337338
if isinstance(item, slice):

autoarray/config/general.yaml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
jax:
2-
use_jax: true # If True, uses JAX internally, whereas False uses normal Numpy.
31
fits:
42
flip_for_ds9: false # If True, the image is flipped before output to a .fits file, which is useful for viewing in DS9.
53
psf:

autoarray/structures/triangles/array.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
MAX_CONTAINING_SIZE = 15
99

10+
1011
class ArrayTriangles(AbstractTriangles):
1112
def __init__(
1213
self,
Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
1-
from autoconf.jax_wrapper import np
2-
from autoarray.structures.triangles.array import ArrayTriangles
3-
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
4-
1+
import jax.numpy as jnp
52
from matplotlib import pyplot as plt
6-
7-
83
import pytest
94

5+
from autoarray.structures.triangles.array import ArrayTriangles
6+
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
7+
108

119
@pytest.fixture
1210
def plot():
1311
plt.figure(figsize=(8, 8))
1412

1513
def plot(triangles, color="black"):
1614
for triangle in triangles:
17-
triangle = np.array(triangle)
18-
triangle = np.append(triangle, np.array([triangle[0]]), axis=0)
15+
triangle = jnp.array(triangle)
16+
triangle = jnp.append(triangle, jnp.array([triangle[0]]), axis=0)
1917
plt.plot(triangle[:, 0], triangle[:, 1], "o-", color=color)
2018

2119
yield plot
@@ -26,27 +24,27 @@ def plot(triangles, color="black"):
2624
@pytest.fixture
2725
def compare_with_nans():
2826
def compare_with_nans_(arr1, arr2):
29-
nan_mask1 = np.isnan(arr1)
30-
nan_mask2 = np.isnan(arr2)
27+
nan_mask1 = jnp.isnan(arr1)
28+
nan_mask2 = jnp.isnan(arr2)
3129

3230
arr1 = arr1[~nan_mask1]
3331
arr2 = arr2[~nan_mask2]
3432

35-
return np.all(arr1 == arr2)
33+
return jnp.all(arr1 == arr2)
3634

3735
return compare_with_nans_
3836

3937

4038
@pytest.fixture
4139
def triangles():
4240
return ArrayTriangles(
43-
indices=np.array(
41+
indices=jnp.array(
4442
[
4543
[0, 1, 2],
4644
[1, 2, 3],
4745
]
4846
),
49-
vertices=np.array(
47+
vertices=jnp.array(
5048
[
5149
[0.0, 0.0],
5250
[1.0, 0.0],
@@ -60,14 +58,14 @@ def triangles():
6058
@pytest.fixture
6159
def one_triangle():
6260
return CoordinateArrayTriangles(
63-
coordinates=np.array([[0, 0]]),
61+
coordinates=jnp.array([[0, 0]]),
6462
side_length=1.0,
6563
)
6664

6765

6866
@pytest.fixture
6967
def two_triangles():
7068
return CoordinateArrayTriangles(
71-
coordinates=np.array([[0, 0], [1, 0]]),
69+
coordinates=jnp.array([[0, 0], [1, 0]]),
7270
side_length=1.0,
7371
)

test_autoarray/structures/triangles/test_coordinate.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
import numpy as np
2+
from jax.tree_util import register_pytree_node_class
23

34
import pytest
45

56
from autoarray.structures.triangles.abstract import HEIGHT_FACTOR
6-
from autoarray.structures.triangles.shape import Point
77

88
from autoarray.structures.triangles.coordinate_array import (
99
CoordinateArrayTriangles,
1010
)
1111

12+
CoordinateArrayTriangles = register_pytree_node_class(CoordinateArrayTriangles)
13+
1214

1315
def test__two(two_triangles):
1416

test_autoarray/structures/triangles/test_jax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ArrayTriangles = register_pytree_node_class(ArrayTriangles)
1111
Point = register_pytree_node_class(Point)
1212

13+
1314
@pytest.fixture
1415
def triangles():
1516
return ArrayTriangles(

0 commit comments

Comments
 (0)