1- from jax import numpy as np
1+ import jax . numpy as jnp
22import jax
33
44from autoarray .structures .triangles .abstract import HEIGHT_FACTOR
5- from autoarray .structures .triangles .coordinate_array .abstract_coordinate_array import (
6- AbstractCoordinateArray ,
7- )
85from autoarray .structures .triangles .array .jax_array import ArrayTriangles
96from autoarray .numpy_wrapper import register_pytree_node_class
107from autoconf import cached_property
1310
1411
1512@register_pytree_node_class
16- class CoordinateArrayTriangles (AbstractCoordinateArray ):
17- @property
18- def numpy (self ):
19- return jax .numpy
13+ class CoordinateArrayTriangles :
2014
2115 @classmethod
2216 def for_limits_and_scale (
@@ -38,7 +32,7 @@ def for_limits_and_scale(
3832 coordinates .append ([x , y ])
3933
4034 return cls (
41- coordinates = np .array (coordinates ),
35+ coordinates = jnp .array (coordinates ),
4236 side_length = scale ,
4337 )
4438
@@ -70,17 +64,17 @@ def tree_unflatten(cls, aux_data, children):
7064 return cls (* children , flipped = aux_data [0 ])
7165
7266 @property
73- def centres (self ) -> np .ndarray :
67+ def centres (self ) -> jnp .ndarray :
7468 """
7569 The centres of the triangles.
7670 """
77- centres = self .scaling_factors * self .coordinates + np .array (
71+ centres = self .scaling_factors * self .coordinates + jnp .array (
7872 [self .x_offset , self .y_offset ]
7973 )
8074 return centres
8175
8276 @cached_property
83- def flip_mask (self ) -> np .ndarray :
77+ def flip_mask (self ) -> jnp .ndarray :
8478 """
8579 A mask for the triangles that are flipped.
8680
@@ -92,11 +86,11 @@ def flip_mask(self) -> np.ndarray:
9286 return mask
9387
9488 @cached_property
95- def flip_array (self ) -> np .ndarray :
89+ def flip_array (self ) -> jnp .ndarray :
9690 """
9791 An array of 1s and -1s to flip the triangles.
9892 """
99- array = np .where (self .flip_mask , - 1 , 1 )
93+ array = jnp .where (self .flip_mask , - 1 , 1 )
10094 return array [:, None ]
10195
10296 def __iter__ (self ):
@@ -113,11 +107,11 @@ def up_sample(self) -> "CoordinateArrayTriangles":
113107
114108 n = coordinates .shape [0 ]
115109
116- shift0 = np .zeros ((n , 2 ))
117- shift3 = np .tile (np .array ([0 , 1 ]), (n , 1 ))
118- shift1 = np .stack ([np .ones (n ), np .where (flip_mask , 1 , 0 )], axis = 1 )
119- shift2 = np .stack ([- np .ones (n ), np .where (flip_mask , 1 , 0 )], axis = 1 )
120- shifts = np .stack ([shift0 , shift1 , shift2 , shift3 ], axis = 1 )
110+ shift0 = jnp .zeros ((n , 2 ))
111+ shift3 = jnp .tile (jnp .array ([0 , 1 ]), (n , 1 ))
112+ shift1 = jnp .stack ([jnp .ones (n ), jnp .where (flip_mask , 1 , 0 )], axis = 1 )
113+ shift2 = jnp .stack ([- jnp .ones (n ), jnp .where (flip_mask , 1 , 0 )], axis = 1 )
114+ shifts = jnp .stack ([shift0 , shift1 , shift2 , shift3 ], axis = 1 )
121115
122116 coordinates_expanded = coordinates [:, None , :]
123117 new_coordinates = coordinates_expanded + shifts
@@ -140,27 +134,27 @@ def neighborhood(self) -> "CoordinateArrayTriangles":
140134 coordinates = self .coordinates
141135 flip_mask = self .flip_mask
142136
143- shift0 = np .zeros ((coordinates .shape [0 ], 2 ))
144- shift1 = np .tile (np .array ([1 , 0 ]), (coordinates .shape [0 ], 1 ))
145- shift2 = np .tile (np .array ([- 1 , 0 ]), (coordinates .shape [0 ], 1 ))
146- shift3 = np .where (
137+ shift0 = jnp .zeros ((coordinates .shape [0 ], 2 ))
138+ shift1 = jnp .tile (jnp .array ([1 , 0 ]), (coordinates .shape [0 ], 1 ))
139+ shift2 = jnp .tile (jnp .array ([- 1 , 0 ]), (coordinates .shape [0 ], 1 ))
140+ shift3 = jnp .where (
147141 flip_mask [:, None ],
148- np .tile (np .array ([0 , 1 ]), (coordinates .shape [0 ], 1 )),
149- np .tile (np .array ([0 , - 1 ]), (coordinates .shape [0 ], 1 )),
142+ jnp .tile (jnp .array ([0 , 1 ]), (coordinates .shape [0 ], 1 )),
143+ jnp .tile (jnp .array ([0 , - 1 ]), (coordinates .shape [0 ], 1 )),
150144 )
151145
152- shifts = np .stack ([shift0 , shift1 , shift2 , shift3 ], axis = 1 )
146+ shifts = jnp .stack ([shift0 , shift1 , shift2 , shift3 ], axis = 1 )
153147
154148 coordinates_expanded = coordinates [:, None , :]
155149 new_coordinates = coordinates_expanded + shifts
156150 new_coordinates = new_coordinates .reshape (- 1 , 2 )
157151
158152 expected_size = 4 * coordinates .shape [0 ]
159- unique_coords , indices = np .unique (
153+ unique_coords , indices = jnp .unique (
160154 new_coordinates ,
161155 axis = 0 ,
162156 size = expected_size ,
163- fill_value = np .nan ,
157+ fill_value = jnp .nan ,
164158 return_index = True ,
165159 )
166160
@@ -175,22 +169,22 @@ def neighborhood(self) -> "CoordinateArrayTriangles":
175169 @cached_property
176170 def _vertices_and_indices (self ):
177171 flat_triangles = self .triangles .reshape (- 1 , 2 )
178- vertices , inverse_indices = np .unique (
172+ vertices , inverse_indices = jnp .unique (
179173 flat_triangles ,
180174 axis = 0 ,
181175 return_inverse = True ,
182176 size = 3 * self .coordinates .shape [0 ],
183177 equal_nan = True ,
184- fill_value = np .nan ,
178+ fill_value = jnp .nan ,
185179 )
186180
187- nan_mask = np .isnan (vertices ).any (axis = 1 )
188- inverse_indices = np .where (nan_mask [inverse_indices ], - 1 , inverse_indices )
181+ nan_mask = jnp .isnan (vertices ).any (axis = 1 )
182+ inverse_indices = jnp .where (nan_mask [inverse_indices ], - 1 , inverse_indices )
189183
190184 indices = inverse_indices .reshape (- 1 , 3 )
191185 return vertices , indices
192186
193- def with_vertices (self , vertices : np .ndarray ) -> ArrayTriangles :
187+ def with_vertices (self , vertices : jnp .ndarray ) -> ArrayTriangles :
194188 """
195189 Create a new set of triangles with the vertices replaced.
196190
@@ -208,7 +202,7 @@ def with_vertices(self, vertices: np.ndarray) -> ArrayTriangles:
208202 vertices = vertices ,
209203 )
210204
211- def for_indexes (self , indexes : np .ndarray ) -> "CoordinateArrayTriangles" :
205+ def for_indexes (self , indexes : jnp .ndarray ) -> "CoordinateArrayTriangles" :
212206 """
213207 Create a new CoordinateArrayTriangles containing triangles corresponding to the given indexes
214208
@@ -222,17 +216,14 @@ def for_indexes(self, indexes: np.ndarray) -> "CoordinateArrayTriangles":
222216 The new CoordinateArrayTriangles instance.
223217 """
224218 mask = indexes == - 1
225- safe_indexes = np .where (mask , 0 , indexes )
226- coordinates = np .take (self .coordinates , safe_indexes , axis = 0 )
227- coordinates = np .where (mask [:, None ], np .nan , coordinates )
219+ safe_indexes = jnp .where (mask , 0 , indexes )
220+ coordinates = jnp .take (self .coordinates , safe_indexes , axis = 0 )
221+ coordinates = jnp .where (mask [:, None ], jnp .nan , coordinates )
228222
229223 return CoordinateArrayTriangles (
230224 coordinates = coordinates ,
231225 side_length = self .side_length ,
232226 y_offset = self .y_offset ,
233227 x_offset = self .x_offset ,
234228 flipped = self .flipped ,
235- )
236-
237- def containing_indices (self , shape : np .ndarray ) -> np .ndarray :
238- raise NotImplementedError ("JAX ArrayTriangles are used for this method." )
229+ )
0 commit comments