Skip to content

Commit d83f4ba

Browse files
Jammy2211Jammy2211
authored andcommitted
simplfiied coordinate_array and removed support for numpy
1 parent 20ceaca commit d83f4ba

File tree

5 files changed

+106
-169
lines changed

5 files changed

+106
-169
lines changed

autoarray/structures/triangles/coordinate_array/jax_coordinate_array.py renamed to autoarray/structures/triangles/coordinate_array.py

Lines changed: 104 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import jax.numpy as jnp
23
import jax
34

@@ -12,6 +13,38 @@
1213
@register_pytree_node_class
1314
class CoordinateArrayTriangles:
1415

16+
def __init__(
17+
self,
18+
coordinates: np.ndarray,
19+
side_length: float = 1.0,
20+
x_offset: float = 0.0,
21+
y_offset: float = 0.0,
22+
flipped: bool = False,
23+
):
24+
"""
25+
Represents a set of triangles by integer coordinates.
26+
27+
Parameters
28+
----------
29+
coordinates
30+
Integer x y coordinates for each triangle.
31+
side_length
32+
The side length of the triangles.
33+
flipped
34+
Whether the triangles are flipped upside down.
35+
y_offset
36+
An y_offset to apply to the y coordinates so that up-sampled triangles align.
37+
"""
38+
self.coordinates = coordinates
39+
self.side_length = side_length
40+
self.flipped = flipped
41+
42+
self.scaling_factors = jnp.array(
43+
[0.5 * side_length, HEIGHT_FACTOR * side_length]
44+
)
45+
self.x_offset = x_offset
46+
self.y_offset = y_offset
47+
1548
@classmethod
1649
def for_limits_and_scale(
1750
cls,
@@ -63,6 +96,12 @@ def tree_unflatten(cls, aux_data, children):
6396
"""
6497
return cls(*children, flipped=aux_data[0])
6598

99+
def __len__(self):
100+
return jnp.count_nonzero(~jnp.isnan(self.coordinates).any(axis=1))
101+
102+
def __iter__(self):
103+
return iter(self.triangles)
104+
66105
@property
67106
def centres(self) -> jnp.ndarray:
68107
"""
@@ -73,6 +112,48 @@ def centres(self) -> jnp.ndarray:
73112
)
74113
return centres
75114

115+
@cached_property
116+
def vertex_coordinates(self) -> np.ndarray:
117+
"""
118+
The vertices of the triangles as an Nx3x2 array.
119+
"""
120+
coordinates = self.coordinates
121+
return jnp.concatenate(
122+
[
123+
coordinates + self.flip_array * np.array([0, 1], dtype=np.int32),
124+
coordinates + self.flip_array * np.array([1, -1], dtype=np.int32),
125+
coordinates + self.flip_array * np.array([-1, -1], dtype=np.int32),
126+
],
127+
dtype=np.int32,
128+
)
129+
130+
@cached_property
131+
def triangles(self) -> np.ndarray:
132+
"""
133+
The vertices of the triangles as an Nx3x2 array.
134+
"""
135+
centres = self.centres
136+
return jnp.stack(
137+
(
138+
centres
139+
+ self.flip_array
140+
* jnp.array(
141+
[0.0, 0.5 * self.side_length * HEIGHT_FACTOR],
142+
),
143+
centres
144+
+ self.flip_array
145+
* jnp.array(
146+
[0.5 * self.side_length, -0.5 * self.side_length * HEIGHT_FACTOR]
147+
),
148+
centres
149+
+ self.flip_array
150+
* jnp.array(
151+
[-0.5 * self.side_length, -0.5 * self.side_length * HEIGHT_FACTOR]
152+
),
153+
),
154+
axis=1,
155+
)
156+
76157
@cached_property
77158
def flip_mask(self) -> jnp.ndarray:
78159
"""
@@ -93,9 +174,6 @@ def flip_array(self) -> jnp.ndarray:
93174
array = jnp.where(self.flip_mask, -1, 1)
94175
return array[:, None]
95176

96-
def __iter__(self):
97-
return iter(self.triangles)
98-
99177
def up_sample(self) -> "CoordinateArrayTriangles":
100178
"""
101179
Up-sample the triangles by adding a new vertex at the midpoint of each edge.
@@ -226,4 +304,26 @@ def for_indexes(self, indexes: jnp.ndarray) -> "CoordinateArrayTriangles":
226304
y_offset=self.y_offset,
227305
x_offset=self.x_offset,
228306
flipped=self.flipped,
229-
)
307+
)
308+
309+
@property
310+
def vertices(self) -> np.ndarray:
311+
"""
312+
The unique vertices of the triangles.
313+
"""
314+
return self._vertices_and_indices[0]
315+
316+
@property
317+
def indices(self) -> np.ndarray:
318+
"""
319+
The indices of the vertices of the triangles.
320+
"""
321+
return self._vertices_and_indices[1]
322+
323+
@property
324+
def means(self):
325+
return jnp.mean(self.triangles, axis=1)
326+
327+
@property
328+
def area(self):
329+
return (3**0.5 / 4 * self.side_length**2) * len(self)

autoarray/structures/triangles/coordinate_array/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

autoarray/structures/triangles/coordinate_array/abstract_coordinate_array.py

Lines changed: 0 additions & 162 deletions
This file was deleted.

test_autoarray/structures/triangles/coordinate/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from autoarray.structures.triangles.coordinate_array import JAXCoordinateArrayTriangles as CoordinateArrayTriangles
5+
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
66

77

88
@pytest.fixture

test_autoarray/structures/triangles/coordinate/test_coordinate_implementation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44

55
from autoarray.structures.triangles.abstract import HEIGHT_FACTOR
6-
from autoarray.structures.triangles.coordinate_array import JAXCoordinateArrayTriangles as CoordinateArrayTriangles
6+
from autoarray.structures.triangles.coordinate_array import CoordinateArrayTriangles
77
from autoarray.structures.triangles.shape import Point
88

99

0 commit comments

Comments
 (0)