File tree Expand file tree Collapse file tree 1 file changed +10
-7
lines changed
Expand file tree Collapse file tree 1 file changed +10
-7
lines changed Original file line number Diff line number Diff line change @@ -374,17 +374,20 @@ def transform_grid_2d_to_reference_frame(
374374 grid
375375 The 2d grid of (y, x) coordinates which are transformed to a new reference frame.
376376 """
377- shifted_grid_2d = grid_2d - jnp .array (centre )
377+ grid_2d = jnp .asarray (grid_2d )
378+ centre = jnp .asarray (centre )
378379
379- radius = jnp . sqrt ( jnp . sum ( jnp . square ( shifted_grid_2d ), axis = 1 ))
380- theta_coordinate_to_profile = jnp . arctan2 (
381- shifted_grid_2d [:, 0 ], shifted_grid_2d [:, 1 ]
382- ) - jnp . radians ( angle )
380+ # Inject a tiny dynamic dependency on `angle` to prevent heavy constant folding
381+ # (adds zero at runtime; negligible cost)
382+ dynamic_zero = jnp . zeros_like ( grid_2d ) * angle
383+ shifted = ( grid_2d + dynamic_zero ) - centre
383384
385+ radius = jnp .sqrt (jnp .sum (shifted * shifted , axis = 1 ))
386+ theta = jnp .arctan2 (shifted [:, 0 ], shifted [:, 1 ]) - jnp .deg2rad (angle )
384387 return jnp .vstack (
385388 [
386- radius * jnp .sin (theta_coordinate_to_profile ),
387- radius * jnp .cos (theta_coordinate_to_profile ),
389+ radius * jnp .sin (theta ),
390+ radius * jnp .cos (theta ),
388391 ]
389392 ).T
390393
You can’t perform that action at this time.
0 commit comments