Skip to content

Commit 20ceaca

Browse files
Jammy2211Jammy2211
authored andcommitted
JAx CoordinateArrayTriangles has explicit JAX use now
1 parent 3933b55 commit 20ceaca

File tree

3 files changed

+32
-56
lines changed

3 files changed

+32
-56
lines changed

autoarray/structures/triangles/abstract.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -122,21 +122,6 @@ def for_indexes(self, indexes: np.ndarray) -> "AbstractTriangles":
122122
The new ArrayTriangles instance.
123123
"""
124124

125-
@abstractmethod
126-
def containing_indices(self, shape: Shape) -> np.ndarray:
127-
"""
128-
Find the triangles that insect with a given shape.
129-
130-
Parameters
131-
----------
132-
shape
133-
The shape
134-
135-
Returns
136-
-------
137-
The indices of triangles that intersect the shape.
138-
"""
139-
140125
@abstractmethod
141126
def neighborhood(self) -> "AbstractTriangles":
142127
"""

autoarray/structures/triangles/array.py

Whitespace-only changes.

autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py

Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
1-
from jax import numpy as np
1+
import jax.numpy as jnp
22
import jax
33

44
from autoarray.structures.triangles.abstract import HEIGHT_FACTOR
5-
from autoarray.structures.triangles.coordinate_array.abstract_coordinate_array import (
6-
AbstractCoordinateArray,
7-
)
85
from autoarray.structures.triangles.array.jax_array import ArrayTriangles
96
from autoarray.numpy_wrapper import register_pytree_node_class
107
from autoconf import cached_property
@@ -13,10 +10,7 @@
1310

1411

1512
@register_pytree_node_class
16-
class CoordinateArrayTriangles(AbstractCoordinateArray):
17-
@property
18-
def numpy(self):
19-
return jax.numpy
13+
class CoordinateArrayTriangles:
2014

2115
@classmethod
2216
def for_limits_and_scale(
@@ -38,7 +32,7 @@ def for_limits_and_scale(
3832
coordinates.append([x, y])
3933

4034
return cls(
41-
coordinates=np.array(coordinates),
35+
coordinates=jnp.array(coordinates),
4236
side_length=scale,
4337
)
4438

@@ -70,17 +64,17 @@ def tree_unflatten(cls, aux_data, children):
7064
return cls(*children, flipped=aux_data[0])
7165

7266
@property
73-
def centres(self) -> np.ndarray:
67+
def centres(self) -> jnp.ndarray:
7468
"""
7569
The centres of the triangles.
7670
"""
77-
centres = self.scaling_factors * self.coordinates + np.array(
71+
centres = self.scaling_factors * self.coordinates + jnp.array(
7872
[self.x_offset, self.y_offset]
7973
)
8074
return centres
8175

8276
@cached_property
83-
def flip_mask(self) -> np.ndarray:
77+
def flip_mask(self) -> jnp.ndarray:
8478
"""
8579
A mask for the triangles that are flipped.
8680
@@ -92,11 +86,11 @@ def flip_mask(self) -> np.ndarray:
9286
return mask
9387

9488
@cached_property
95-
def flip_array(self) -> np.ndarray:
89+
def flip_array(self) -> jnp.ndarray:
9690
"""
9791
An array of 1s and -1s to flip the triangles.
9892
"""
99-
array = np.where(self.flip_mask, -1, 1)
93+
array = jnp.where(self.flip_mask, -1, 1)
10094
return array[:, None]
10195

10296
def __iter__(self):
@@ -113,11 +107,11 @@ def up_sample(self) -> "CoordinateArrayTriangles":
113107

114108
n = coordinates.shape[0]
115109

116-
shift0 = np.zeros((n, 2))
117-
shift3 = np.tile(np.array([0, 1]), (n, 1))
118-
shift1 = np.stack([np.ones(n), np.where(flip_mask, 1, 0)], axis=1)
119-
shift2 = np.stack([-np.ones(n), np.where(flip_mask, 1, 0)], axis=1)
120-
shifts = np.stack([shift0, shift1, shift2, shift3], axis=1)
110+
shift0 = jnp.zeros((n, 2))
111+
shift3 = jnp.tile(jnp.array([0, 1]), (n, 1))
112+
shift1 = jnp.stack([jnp.ones(n), jnp.where(flip_mask, 1, 0)], axis=1)
113+
shift2 = jnp.stack([-jnp.ones(n), jnp.where(flip_mask, 1, 0)], axis=1)
114+
shifts = jnp.stack([shift0, shift1, shift2, shift3], axis=1)
121115

122116
coordinates_expanded = coordinates[:, None, :]
123117
new_coordinates = coordinates_expanded + shifts
@@ -140,27 +134,27 @@ def neighborhood(self) -> "CoordinateArrayTriangles":
140134
coordinates = self.coordinates
141135
flip_mask = self.flip_mask
142136

143-
shift0 = np.zeros((coordinates.shape[0], 2))
144-
shift1 = np.tile(np.array([1, 0]), (coordinates.shape[0], 1))
145-
shift2 = np.tile(np.array([-1, 0]), (coordinates.shape[0], 1))
146-
shift3 = np.where(
137+
shift0 = jnp.zeros((coordinates.shape[0], 2))
138+
shift1 = jnp.tile(jnp.array([1, 0]), (coordinates.shape[0], 1))
139+
shift2 = jnp.tile(jnp.array([-1, 0]), (coordinates.shape[0], 1))
140+
shift3 = jnp.where(
147141
flip_mask[:, None],
148-
np.tile(np.array([0, 1]), (coordinates.shape[0], 1)),
149-
np.tile(np.array([0, -1]), (coordinates.shape[0], 1)),
142+
jnp.tile(jnp.array([0, 1]), (coordinates.shape[0], 1)),
143+
jnp.tile(jnp.array([0, -1]), (coordinates.shape[0], 1)),
150144
)
151145

152-
shifts = np.stack([shift0, shift1, shift2, shift3], axis=1)
146+
shifts = jnp.stack([shift0, shift1, shift2, shift3], axis=1)
153147

154148
coordinates_expanded = coordinates[:, None, :]
155149
new_coordinates = coordinates_expanded + shifts
156150
new_coordinates = new_coordinates.reshape(-1, 2)
157151

158152
expected_size = 4 * coordinates.shape[0]
159-
unique_coords, indices = np.unique(
153+
unique_coords, indices = jnp.unique(
160154
new_coordinates,
161155
axis=0,
162156
size=expected_size,
163-
fill_value=np.nan,
157+
fill_value=jnp.nan,
164158
return_index=True,
165159
)
166160

@@ -175,22 +169,22 @@ def neighborhood(self) -> "CoordinateArrayTriangles":
175169
@cached_property
176170
def _vertices_and_indices(self):
177171
flat_triangles = self.triangles.reshape(-1, 2)
178-
vertices, inverse_indices = np.unique(
172+
vertices, inverse_indices = jnp.unique(
179173
flat_triangles,
180174
axis=0,
181175
return_inverse=True,
182176
size=3 * self.coordinates.shape[0],
183177
equal_nan=True,
184-
fill_value=np.nan,
178+
fill_value=jnp.nan,
185179
)
186180

187-
nan_mask = np.isnan(vertices).any(axis=1)
188-
inverse_indices = np.where(nan_mask[inverse_indices], -1, inverse_indices)
181+
nan_mask = jnp.isnan(vertices).any(axis=1)
182+
inverse_indices = jnp.where(nan_mask[inverse_indices], -1, inverse_indices)
189183

190184
indices = inverse_indices.reshape(-1, 3)
191185
return vertices, indices
192186

193-
def with_vertices(self, vertices: np.ndarray) -> ArrayTriangles:
187+
def with_vertices(self, vertices: jnp.ndarray) -> ArrayTriangles:
194188
"""
195189
Create a new set of triangles with the vertices replaced.
196190
@@ -208,7 +202,7 @@ def with_vertices(self, vertices: np.ndarray) -> ArrayTriangles:
208202
vertices=vertices,
209203
)
210204

211-
def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles":
205+
def for_indexes(self, indexes: jnp.ndarray) -> "CoordinateArrayTriangles":
212206
"""
213207
Create a new CoordinateArrayTriangles containing triangles corresponding to the given indexes
214208
@@ -222,17 +216,14 @@ def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles":
222216
The new CoordinateArrayTriangles instance.
223217
"""
224218
mask = indexes == -1
225-
safe_indexes = np.where(mask, 0, indexes)
226-
coordinates = np.take(self.coordinates, safe_indexes, axis=0)
227-
coordinates = np.where(mask[:, None], np.nan, coordinates)
219+
safe_indexes = jnp.where(mask, 0, indexes)
220+
coordinates = jnp.take(self.coordinates, safe_indexes, axis=0)
221+
coordinates = jnp.where(mask[:, None], jnp.nan, coordinates)
228222

229223
return CoordinateArrayTriangles(
230224
coordinates=coordinates,
231225
side_length=self.side_length,
232226
y_offset=self.y_offset,
233227
x_offset=self.x_offset,
234228
flipped=self.flipped,
235-
)
236-
237-
def containing_indices(self, shape: np.ndarray) -> np.ndarray:
238-
raise NotImplementedError("JAX ArrayTriangles are used for this method.")
229+
)

0 commit comments

Comments
 (0)