Skip to content

Commit 67ad463

Browse files
Jammy2211Jammy2211
authored andcommitted
speed up JAX compiltation a lot via cleveer trick
1 parent 0d25172 commit 67ad463

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

autoarray/geometry/geometry_util.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -374,17 +374,20 @@ def transform_grid_2d_to_reference_frame(
374374
grid
375375
The 2d grid of (y, x) coordinates which are transformed to a new reference frame.
376376
"""
377-
shifted_grid_2d = grid_2d - jnp.array(centre)
377+
grid_2d = jnp.asarray(grid_2d)
378+
centre = jnp.asarray(centre)
378379

379-
radius = jnp.sqrt(jnp.sum(jnp.square(shifted_grid_2d), axis=1))
380-
theta_coordinate_to_profile = jnp.arctan2(
381-
shifted_grid_2d[:, 0], shifted_grid_2d[:, 1]
382-
) - jnp.radians(angle)
380+
# Inject a tiny dynamic dependency on `angle` to prevent heavy constant folding
381+
# (adds zero at runtime; negligible cost)
382+
dynamic_zero = jnp.zeros_like(grid_2d) * angle
383+
shifted = (grid_2d + dynamic_zero) - centre
383384

385+
radius = jnp.sqrt(jnp.sum(shifted * shifted, axis=1))
386+
theta = jnp.arctan2(shifted[:, 0], shifted[:, 1]) - jnp.deg2rad(angle)
384387
return jnp.vstack(
385388
[
386-
radius * jnp.sin(theta_coordinate_to_profile),
387-
radius * jnp.cos(theta_coordinate_to_profile),
389+
radius * jnp.sin(theta),
390+
radius * jnp.cos(theta),
388391
]
389392
).T
390393

0 commit comments

Comments
 (0)