Skip to content

Commit a14989d

Browse files
Jammy2211Jammy2211
authored andcommitted
colemans new code implemented
1 parent 6fbb976 commit a14989d

File tree

4 files changed

+314
-84
lines changed

4 files changed

+314
-84
lines changed

autoarray/inversion/pixelization/mappers/mapper_util.py

Lines changed: 26 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,90 +2,27 @@
22
import numpy as np
33
from typing import Tuple
44

5+
def transform_and_inv_transform_from(source_grid_scaled, mesh_weight_map, xp=np):
56

6-
def forward_interp(xp, yp, x):
7-
8-
import jax
9-
import jax.numpy as jnp
10-
11-
return jax.vmap(jnp.interp, in_axes=(1, 1, 1, None, None), out_axes=(1))(
12-
x, xp, yp, 0, 1
13-
)
14-
15-
16-
def reverse_interp(xp, yp, x):
17-
import jax
18-
import jax.numpy as jnp
19-
20-
return jax.vmap(jnp.interp, in_axes=(1, 1, 1), out_axes=(1))(x, xp, yp)
21-
22-
23-
def forward_interp_np(xp, yp, x):
24-
"""
25-
xp: (N, M)
26-
yp: (N, M)
27-
x : (M,) ← one x per column
28-
"""
29-
30-
if yp.ndim == 1 and xp.ndim == 2:
31-
yp = np.broadcast_to(yp[:, None], xp.shape)
32-
33-
K, M = x.shape
34-
35-
out = np.empty((K, 2), dtype=xp.dtype)
36-
37-
for j in range(2):
38-
out[:, j] = np.interp(x[:, j], xp[:, j], yp[:, j], left=0, right=1)
39-
40-
return out
41-
42-
43-
def reverse_interp_np(xp, yp, x):
44-
"""
45-
xp : (N,) or (N, M)
46-
yp : (N, M)
47-
x : (K, M) query points per column
48-
"""
49-
50-
# Ensure xp is 2D: (N, M)
51-
if xp.ndim == 1 and yp.ndim == 2: # (N, 1)
52-
xp = np.broadcast_to(xp[:, None], yp.shape)
53-
54-
# Shapes
55-
K, M = x.shape
56-
57-
# Output
58-
out = np.empty((K, 2), dtype=yp.dtype)
59-
60-
# Column-wise interpolation (cannot avoid this loop in pure NumPy)
61-
for j in range(2):
62-
out[:, j] = np.interp(x[:, j], xp[:, j], yp[:, j])
63-
64-
return out
65-
7+
if xp.__name__.startswith("jax"):
668

67-
def create_transforms(traced_points, mesh_weight_map=None, xp=np):
9+
from autoarray.inversion.pixelization.mappers.rectangular_interp import util_jax
6810

69-
N = traced_points.shape[0] # // 2
11+
transform, inv_transform = util_jax.create_transforms(
12+
source_grid_scaled,
13+
deg=21,
14+
mesh_weight_map=mesh_weight_map
15+
)
7016

71-
if mesh_weight_map is None:
72-
t = xp.arange(1, N + 1) / (N + 1)
73-
t = xp.stack([t, t], axis=1)
74-
sort_points = xp.sort(traced_points, axis=0) # [::2]
7517
else:
76-
sdx = xp.argsort(traced_points, axis=0)
77-
sort_points = xp.take_along_axis(traced_points, sdx, axis=0)
78-
t = xp.stack([mesh_weight_map, mesh_weight_map], axis=1)
79-
t = xp.take_along_axis(t, sdx, axis=0)
80-
t = xp.cumsum(t, axis=0)
8118

82-
if xp.__name__.startswith("jax"):
83-
transform = partial(forward_interp, sort_points, t)
84-
inv_transform = partial(reverse_interp, t, sort_points)
85-
return transform, inv_transform
19+
from autoarray.inversion.pixelization.mappers.rectangular_interp import util_np
20+
21+
transform, inv_transform = util_np.create_transforms(
22+
source_grid_scaled,
23+
mesh_weight_map=mesh_weight_map
24+
)
8625

87-
transform = partial(forward_interp_np, sort_points, t)
88-
inv_transform = partial(reverse_interp_np, t, sort_points)
8926
return transform, inv_transform
9027

9128

@@ -97,8 +34,10 @@ def adaptive_rectangular_transformed_grid_from(
9734
scale = source_plane_data_grid.std(axis=0).min()
9835
source_grid_scaled = (source_plane_data_grid - mu) / scale
9936

100-
transform, inv_transform = create_transforms(
101-
source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp
37+
transform, inv_transform = transform_and_inv_transform_from(
38+
source_grid_scaled=source_grid_scaled,
39+
mesh_weight_map=mesh_weight_map,
40+
xp=xp
10241
)
10342

10443
def inv_full(U):
@@ -118,8 +57,10 @@ def adaptive_rectangular_areas_from(
11857
scale = source_plane_data_grid.std(axis=0).min()
11958
source_grid_scaled = (source_plane_data_grid - mu) / scale
12059

121-
transform, inv_transform = create_transforms(
122-
source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp
60+
transform, inv_transform = transform_and_inv_transform_from(
61+
source_grid_scaled=source_grid_scaled,
62+
mesh_weight_map=mesh_weight_map,
63+
xp=xp
12364
)
12465

12566
def inv_full(U):
@@ -194,15 +135,16 @@ def adaptive_rectangular_mappings_weights_via_interpolation_from(
194135
The bilinear interpolation weights for each of the four neighboring pixels.
195136
Order: [w_bl, w_br, w_tl, w_tr].
196137
"""
197-
198138
# --- Step 1. Normalize grid ---
199139
mu = source_plane_data_grid.mean(axis=0)
200140
scale = source_plane_data_grid.std(axis=0).min()
201141
source_grid_scaled = (source_plane_data_grid - mu) / scale
202142

203143
# --- Step 2. Build transforms ---
204-
transform, inv_transform = create_transforms(
205-
source_grid_scaled, mesh_weight_map=mesh_weight_map, xp=xp
144+
transform, inv_transform = transform_and_inv_transform_from(
145+
source_grid_scaled=source_grid_scaled,
146+
mesh_weight_map=mesh_weight_map,
147+
xp=xp
206148
)
207149

208150
# --- Step 3. Transform oversampled grid into index space ---

autoarray/inversion/pixelization/mappers/rectangular_interp/__init__.py

Whitespace-only changes.
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
from functools import partial
2+
import numpy as np
3+
from typing import Tuple
4+
5+
6+
def create_transforms(source_grid_scaled, deg=11, mesh_weight_map=None):
7+
8+
import jax
9+
import jax.numpy as jnp
10+
from jax.tree_util import register_pytree_node_class
11+
12+
@jax.jit
13+
def interp1d(x, xp, fp): #, left=None, right=None):
14+
i = jnp.clip(
15+
jnp.searchsorted(
16+
xp,
17+
x,
18+
side='right',
19+
method='scan_unrolled'
20+
),
21+
1,
22+
len(xp) - 1
23+
)
24+
df = fp[i] - fp[i - 1]
25+
dx = xp[i] - xp[i - 1]
26+
delta = x - xp[i - 1]
27+
eps = jnp.finfo(xp.dtype).eps
28+
epsilon = jnp.nextafter(eps, jnp.inf) - eps
29+
30+
dx0 = jax.lax.abs(dx) <= epsilon # Prevent NaN gradients when `dx` is small.
31+
f = jnp.where(
32+
dx0,
33+
fp[i - 1],
34+
fp[i - 1] + (delta / jnp.where(dx0, 1, dx)) * df
35+
)
36+
37+
return f
38+
39+
@jax.custom_jvp
40+
def spline_invert(ip, x):
41+
# use a custom jvp because we are using cached values to get the spline faster
42+
# and this would not easily give the grad with respect to the poly coefs as written
43+
k_right = jnp.digitize(x, ip.x_low_res, method='scan_unrolled')
44+
k_left = k_right - 1
45+
46+
# jax's default out-of-bound index gives
47+
# correct result for point on the right most
48+
# edge of interpolation, no need to do anything
49+
# special for the boundary
50+
t = (x - ip.x_low_res[k_left]) / ip.delta_x[k_left]
51+
t2 = t ** 2
52+
t3 = t ** 3
53+
h00 = 2 * t3 - 3 * t2 + 1
54+
h10 = t3 - 2 * t2 + t
55+
h01 = -2 * t3 + 3 * t2
56+
h11 = t3 - t2
57+
return ip.y_low_res[k_left] * h00 + ip.y_low_res[k_right] * h01 + (
58+
ip.dy_low_res[k_left] * h10 + ip.dy_low_res[k_right] * h11) * ip.delta_x[k_left]
59+
60+
@spline_invert.defjvp
61+
def invert_poly_jvp(primals, tangents):
62+
# because this is the inverse of a polynomial it's
63+
# gradient can be written in terms of the gradient
64+
# of the polynomial evaluated at the output
65+
# this is easy to write down and avoids needing
66+
# to grad through the cubic spline inversion
67+
ip, x = primals
68+
ip_dot, x_dot = tangents
69+
primal_out = spline_invert(ip, x)
70+
d_dx = 1 / jnp.polyval(ip.dcoefs, primal_out)
71+
d_dcoefs = -jnp.vander(jnp.atleast_1d(primal_out), N=ip.coefs.shape[0])
72+
tangent_out = ((ip_dot.coefs * d_dcoefs).sum() + x_dot) * d_dx
73+
return primal_out, tangent_out
74+
75+
@register_pytree_node_class
76+
class InvertPolySpline:
77+
@staticmethod
78+
def v_polyder(c):
79+
return jax.vmap(
80+
jnp.polyder,
81+
in_axes=1,
82+
out_axes=1
83+
)(c)
84+
85+
@staticmethod
86+
def v_polyval(c, x):
87+
return jax.vmap(
88+
jnp.polyval,
89+
in_axes=(1, 1),
90+
out_axes=(1)
91+
)(c, x)
92+
93+
def __init__(self, coefs, lower_bound, upper_bound, low_res=150):
94+
# coefs Nx2
95+
# lower_bound 1x2
96+
# upper_bound 1x2
97+
# low_res int
98+
99+
# polynomial to inverse
100+
self.coefs = coefs
101+
102+
# get 1st derivative of polynomial
103+
self.dcoefs = InvertPolySpline.v_polyder(self.coefs)
104+
105+
# The bounds of the CDF
106+
# below will always be 0
107+
# above will always be 1
108+
self.lower_bound = jnp.atleast_2d(lower_bound)
109+
self.upper_bound = jnp.atleast_2d(upper_bound)
110+
111+
# low resolution grid of nodes for spline approx to the inverse function
112+
# cubic spline needs the function, derivative, and delta_x at each node
113+
self.low_res = low_res
114+
y_low_res = jnp.linspace(0, 1, low_res)
115+
self.y_low_res = jnp.stack([y_low_res, y_low_res], axis=1)
116+
self.x_low_res = InvertPolySpline.v_polyval(self.coefs, self.y_low_res)
117+
self.dy_low_res = 1 / InvertPolySpline.v_polyval(self.dcoefs, self.y_low_res)
118+
self.delta_x = jnp.diff(self.x_low_res, axis=0)
119+
120+
def __repr__(self):
121+
return f'InvertPoly(coefs={self.coefs}, lower_bound={self.lower_bound}, upper_bound={self.upper_bound})'
122+
123+
def tree_flatten(self):
124+
children = (
125+
self.coefs,
126+
self.dcoefs,
127+
self.y_low_res,
128+
self.x_low_res,
129+
self.dy_low_res,
130+
self.delta_x,
131+
self.lower_bound,
132+
self.upper_bound
133+
)
134+
aux_data = (self.low_res,)
135+
return (children, aux_data)
136+
137+
@classmethod
138+
def tree_unflatten(cls, aux_data, children):
139+
# return cls(*(children + aux_data))
140+
obj = object.__new__(InvertPolySpline)
141+
obj.coefs = children[0]
142+
obj.dcoefs = children[1]
143+
obj.y_low_res = children[2]
144+
obj.x_low_res = children[3]
145+
obj.dy_low_res = children[4]
146+
obj.delta_x = children[5]
147+
obj.lower_bound = children[6]
148+
obj.upper_bound = children[7]
149+
obj.low_res = aux_data[0]
150+
return obj
151+
152+
def fwd_transform(self, x):
153+
y = jax.vmap(
154+
spline_invert,
155+
in_axes=(1, 1),
156+
out_axes=(1)
157+
)(self, x)
158+
y = jnp.where(x <= self.lower_bound, 0.0, y)
159+
y = jnp.where(x >= self.upper_bound, 1.0, y)
160+
return jnp.clip(y, 0.0, 1.0)
161+
162+
def rev_transform(self, y):
163+
return InvertPolySpline.v_polyval(self.coefs, y)
164+
165+
v_polyfit = jax.vmap(jnp.polyfit, in_axes=(1, 1, None, None, None, 1), out_axes=(1))
166+
v_gradient = jax.vmap(jnp.gradient, in_axes=(1, 1), out_axes=1)
167+
168+
# inv_poly is a pytree, it can be returned from `jit` without issue :D
169+
@partial(jax.jit, static_argnames=('deg'))
170+
def create_transforms_spline(traced_points, deg=11, mesh_weight_map=None):
171+
172+
N = traced_points.shape[0] # // 2
173+
if mesh_weight_map is None:
174+
t = jnp.arange(1, N + 1) / (N + 1)
175+
t = jnp.stack([t, t], axis=1)
176+
sort_points = jnp.sort(traced_points, axis=0) # [::2]
177+
else:
178+
sdx = jnp.argsort(traced_points, axis=0)
179+
sort_points = jnp.take_along_axis(traced_points, sdx, axis=0)
180+
t = jnp.stack([mesh_weight_map, mesh_weight_map], axis=1)
181+
t = jnp.take_along_axis(t, sdx, axis=0)
182+
t = jnp.cumsum(t, axis=0)
183+
184+
# The CDF estimation needs to be a smooth function to avoid noise caused by
185+
# using a sub-set of traced points
186+
#
187+
# A polynomial is fit to the *inverse* CDF, this polynomial is inverted
188+
# numerically to get the smooth CDF function.
189+
#
190+
# The polynomial is fit to 'y' points at the Chebyshev nodes to avoid the
191+
# Runge phenomenon and to estimate the gradient of the CDF
192+
#
193+
# The gradient of the CDF is use as the weights for the polynomial fit
194+
# (e.g. where the CDF changes rapidly the weight is higher). This
195+
# helps prevent overfitting for high degree polynomials and helps keep
196+
# log degree polynomials monotonic.
197+
198+
# Use 3x more Chebyshev nodes than the degree being fit
199+
cheb_deg = 3 * deg
200+
# calculate nodes and interpolated values at the nodes
201+
cheb_nodes = ((jnp.cos((2 * jnp.arange(cheb_deg) + 1) * jnp.pi / (2 * cheb_deg))[::-1]) + 1) / 2
202+
cy = jnp.stack([cheb_nodes, cheb_nodes], axis=1)
203+
cx = jax.vmap(interp1d, in_axes=(None, 1, 1), out_axes=1)(cheb_nodes, t, sort_points)
204+
205+
# fit the polynomial with weights
206+
w = v_gradient(cy, cx)
207+
coefs = v_polyfit(cy, cx, deg, None, False, w)
208+
209+
# invert the polynomial with custom class
210+
inv_poly = InvertPolySpline(coefs, sort_points[0], sort_points[-1], low_res=20 * deg)
211+
# return sort_points and t for plotting below, in production we just need the transforms
212+
return inv_poly, sort_points, t
213+
214+
ips, sort_points, t = create_transforms_spline(source_grid_scaled, deg=deg, mesh_weight_map=mesh_weight_map)
215+
216+
transform = jax.jit(ips.fwd_transform)
217+
inv_transform = jax.jit(ips.rev_transform)
218+
219+
return transform, inv_transform

0 commit comments

Comments
 (0)