1+ import numpy as np
12import jax .numpy as jnp
23import jax
34
1213@register_pytree_node_class
1314class CoordinateArrayTriangles :
1415
16+ def __init__ (
17+ self ,
18+ coordinates : np .ndarray ,
19+ side_length : float = 1.0 ,
20+ x_offset : float = 0.0 ,
21+ y_offset : float = 0.0 ,
22+ flipped : bool = False ,
23+ ):
24+ """
25+ Represents a set of triangles by integer coordinates.
26+
27+ Parameters
28+ ----------
29+ coordinates
30+ Integer x y coordinates for each triangle.
31+ side_length
32+ The side length of the triangles.
33+ flipped
34+ Whether the triangles are flipped upside down.
35+ y_offset
36+ An y_offset to apply to the y coordinates so that up-sampled triangles align.
37+ """
38+ self .coordinates = coordinates
39+ self .side_length = side_length
40+ self .flipped = flipped
41+
42+ self .scaling_factors = jnp .array (
43+ [0.5 * side_length , HEIGHT_FACTOR * side_length ]
44+ )
45+ self .x_offset = x_offset
46+ self .y_offset = y_offset
47+
1548 @classmethod
1649 def for_limits_and_scale (
1750 cls ,
@@ -63,6 +96,12 @@ def tree_unflatten(cls, aux_data, children):
6396 """
6497 return cls (* children , flipped = aux_data [0 ])
6598
99+ def __len__ (self ):
100+ return jnp .count_nonzero (~ jnp .isnan (self .coordinates ).any (axis = 1 ))
101+
102+ def __iter__ (self ):
103+ return iter (self .triangles )
104+
66105 @property
67106 def centres (self ) -> jnp .ndarray :
68107 """
@@ -73,6 +112,48 @@ def centres(self) -> jnp.ndarray:
73112 )
74113 return centres
75114
115+ @cached_property
116+ def vertex_coordinates (self ) -> np .ndarray :
117+ """
118+ The vertices of the triangles as an Nx3x2 array.
119+ """
120+ coordinates = self .coordinates
121+ return jnp .concatenate (
122+ [
123+ coordinates + self .flip_array * np .array ([0 , 1 ], dtype = np .int32 ),
124+ coordinates + self .flip_array * np .array ([1 , - 1 ], dtype = np .int32 ),
125+ coordinates + self .flip_array * np .array ([- 1 , - 1 ], dtype = np .int32 ),
126+ ],
127+ dtype = np .int32 ,
128+ )
129+
130+ @cached_property
131+ def triangles (self ) -> np .ndarray :
132+ """
133+ The vertices of the triangles as an Nx3x2 array.
134+ """
135+ centres = self .centres
136+ return jnp .stack (
137+ (
138+ centres
139+ + self .flip_array
140+ * jnp .array (
141+ [0.0 , 0.5 * self .side_length * HEIGHT_FACTOR ],
142+ ),
143+ centres
144+ + self .flip_array
145+ * jnp .array (
146+ [0.5 * self .side_length , - 0.5 * self .side_length * HEIGHT_FACTOR ]
147+ ),
148+ centres
149+ + self .flip_array
150+ * jnp .array (
151+ [- 0.5 * self .side_length , - 0.5 * self .side_length * HEIGHT_FACTOR ]
152+ ),
153+ ),
154+ axis = 1 ,
155+ )
156+
76157 @cached_property
77158 def flip_mask (self ) -> jnp .ndarray :
78159 """
@@ -93,9 +174,6 @@ def flip_array(self) -> jnp.ndarray:
93174 array = jnp .where (self .flip_mask , - 1 , 1 )
94175 return array [:, None ]
95176
96- def __iter__ (self ):
97- return iter (self .triangles )
98-
99177 def up_sample (self ) -> "CoordinateArrayTriangles" :
100178 """
101179 Up-sample the triangles by adding a new vertex at the midpoint of each edge.
@@ -226,4 +304,26 @@ def for_indexes(self, indexes: jnp.ndarray) -> "CoordinateArrayTriangles":
226304 y_offset = self .y_offset ,
227305 x_offset = self .x_offset ,
228306 flipped = self .flipped ,
229- )
307+ )
308+
309+ @property
310+ def vertices (self ) -> np .ndarray :
311+ """
312+ The unique vertices of the triangles.
313+ """
314+ return self ._vertices_and_indices [0 ]
315+
316+ @property
317+ def indices (self ) -> np .ndarray :
318+ """
319+ The indices of the vertices of the triangles.
320+ """
321+ return self ._vertices_and_indices [1 ]
322+
323+ @property
324+ def means (self ):
325+ return jnp .mean (self .triangles , axis = 1 )
326+
327+ @property
328+ def area (self ):
329+ return (3 ** 0.5 / 4 * self .side_length ** 2 ) * len (self )
0 commit comments