1+ from functools import partial
2+ import numpy as np
3+ from typing import Tuple
4+
5+
6+ def create_transforms (source_grid_scaled , deg = 11 , mesh_weight_map = None ):
7+
8+ import jax
9+ import jax .numpy as jnp
10+ from jax .tree_util import register_pytree_node_class
11+
12+ @jax .jit
13+ def interp1d (x , xp , fp ): #, left=None, right=None):
14+ i = jnp .clip (
15+ jnp .searchsorted (
16+ xp ,
17+ x ,
18+ side = 'right' ,
19+ method = 'scan_unrolled'
20+ ),
21+ 1 ,
22+ len (xp ) - 1
23+ )
24+ df = fp [i ] - fp [i - 1 ]
25+ dx = xp [i ] - xp [i - 1 ]
26+ delta = x - xp [i - 1 ]
27+ eps = jnp .finfo (xp .dtype ).eps
28+ epsilon = jnp .nextafter (eps , jnp .inf ) - eps
29+
30+ dx0 = jax .lax .abs (dx ) <= epsilon # Prevent NaN gradients when `dx` is small.
31+ f = jnp .where (
32+ dx0 ,
33+ fp [i - 1 ],
34+ fp [i - 1 ] + (delta / jnp .where (dx0 , 1 , dx )) * df
35+ )
36+
37+ return f
38+
39+ @jax .custom_jvp
40+ def spline_invert (ip , x ):
41+ # use a custom jvp because we are using cached values to get the spline faster
42+ # and this would not easily give the grad with respect to the poly coefs as written
43+ k_right = jnp .digitize (x , ip .x_low_res , method = 'scan_unrolled' )
44+ k_left = k_right - 1
45+
46+ # jax's default out-of-bound index gives
47+ # correct result for point on the right most
48+ # edge of interpolation, no need to do anything
49+ # special for the boundary
50+ t = (x - ip .x_low_res [k_left ]) / ip .delta_x [k_left ]
51+ t2 = t ** 2
52+ t3 = t ** 3
53+ h00 = 2 * t3 - 3 * t2 + 1
54+ h10 = t3 - 2 * t2 + t
55+ h01 = - 2 * t3 + 3 * t2
56+ h11 = t3 - t2
57+ return ip .y_low_res [k_left ] * h00 + ip .y_low_res [k_right ] * h01 + (
58+ ip .dy_low_res [k_left ] * h10 + ip .dy_low_res [k_right ] * h11 ) * ip .delta_x [k_left ]
59+
60+ @spline_invert .defjvp
61+ def invert_poly_jvp (primals , tangents ):
62+ # because this is the inverse of a polynomial it's
63+ # gradient can be written in terms of the gradient
64+ # of the polynomial evaluated at the output
65+ # this is easy to write down and avoids needing
66+ # to grad through the cubic spline inversion
67+ ip , x = primals
68+ ip_dot , x_dot = tangents
69+ primal_out = spline_invert (ip , x )
70+ d_dx = 1 / jnp .polyval (ip .dcoefs , primal_out )
71+ d_dcoefs = - jnp .vander (jnp .atleast_1d (primal_out ), N = ip .coefs .shape [0 ])
72+ tangent_out = ((ip_dot .coefs * d_dcoefs ).sum () + x_dot ) * d_dx
73+ return primal_out , tangent_out
74+
75+ @register_pytree_node_class
76+ class InvertPolySpline :
77+ @staticmethod
78+ def v_polyder (c ):
79+ return jax .vmap (
80+ jnp .polyder ,
81+ in_axes = 1 ,
82+ out_axes = 1
83+ )(c )
84+
85+ @staticmethod
86+ def v_polyval (c , x ):
87+ return jax .vmap (
88+ jnp .polyval ,
89+ in_axes = (1 , 1 ),
90+ out_axes = (1 )
91+ )(c , x )
92+
93+ def __init__ (self , coefs , lower_bound , upper_bound , low_res = 150 ):
94+ # coefs Nx2
95+ # lower_bound 1x2
96+ # upper_bound 1x2
97+ # low_res int
98+
99+ # polynomial to inverse
100+ self .coefs = coefs
101+
102+ # get 1st derivative of polynomial
103+ self .dcoefs = InvertPolySpline .v_polyder (self .coefs )
104+
105+ # The bounds of the CDF
106+ # below will always be 0
107+ # above will always be 1
108+ self .lower_bound = jnp .atleast_2d (lower_bound )
109+ self .upper_bound = jnp .atleast_2d (upper_bound )
110+
111+ # low resolution grid of nodes for spline approx to the inverse function
112+ # cubic spline needs the function, derivative, and delta_x at each node
113+ self .low_res = low_res
114+ y_low_res = jnp .linspace (0 , 1 , low_res )
115+ self .y_low_res = jnp .stack ([y_low_res , y_low_res ], axis = 1 )
116+ self .x_low_res = InvertPolySpline .v_polyval (self .coefs , self .y_low_res )
117+ self .dy_low_res = 1 / InvertPolySpline .v_polyval (self .dcoefs , self .y_low_res )
118+ self .delta_x = jnp .diff (self .x_low_res , axis = 0 )
119+
120+ def __repr__ (self ):
121+ return f'InvertPoly(coefs={ self .coefs } , lower_bound={ self .lower_bound } , upper_bound={ self .upper_bound } )'
122+
123+ def tree_flatten (self ):
124+ children = (
125+ self .coefs ,
126+ self .dcoefs ,
127+ self .y_low_res ,
128+ self .x_low_res ,
129+ self .dy_low_res ,
130+ self .delta_x ,
131+ self .lower_bound ,
132+ self .upper_bound
133+ )
134+ aux_data = (self .low_res ,)
135+ return (children , aux_data )
136+
137+ @classmethod
138+ def tree_unflatten (cls , aux_data , children ):
139+ # return cls(*(children + aux_data))
140+ obj = object .__new__ (InvertPolySpline )
141+ obj .coefs = children [0 ]
142+ obj .dcoefs = children [1 ]
143+ obj .y_low_res = children [2 ]
144+ obj .x_low_res = children [3 ]
145+ obj .dy_low_res = children [4 ]
146+ obj .delta_x = children [5 ]
147+ obj .lower_bound = children [6 ]
148+ obj .upper_bound = children [7 ]
149+ obj .low_res = aux_data [0 ]
150+ return obj
151+
152+ def fwd_transform (self , x ):
153+ y = jax .vmap (
154+ spline_invert ,
155+ in_axes = (1 , 1 ),
156+ out_axes = (1 )
157+ )(self , x )
158+ y = jnp .where (x <= self .lower_bound , 0.0 , y )
159+ y = jnp .where (x >= self .upper_bound , 1.0 , y )
160+ return jnp .clip (y , 0.0 , 1.0 )
161+
162+ def rev_transform (self , y ):
163+ return InvertPolySpline .v_polyval (self .coefs , y )
164+
165+ v_polyfit = jax .vmap (jnp .polyfit , in_axes = (1 , 1 , None , None , None , 1 ), out_axes = (1 ))
166+ v_gradient = jax .vmap (jnp .gradient , in_axes = (1 , 1 ), out_axes = 1 )
167+
168+ # inv_poly is a pytree, it can be returned from `jit` without issue :D
169+ @partial (jax .jit , static_argnames = ('deg' ))
170+ def create_transforms_spline (traced_points , deg = 11 , mesh_weight_map = None ):
171+
172+ N = traced_points .shape [0 ] # // 2
173+ if mesh_weight_map is None :
174+ t = jnp .arange (1 , N + 1 ) / (N + 1 )
175+ t = jnp .stack ([t , t ], axis = 1 )
176+ sort_points = jnp .sort (traced_points , axis = 0 ) # [::2]
177+ else :
178+ sdx = jnp .argsort (traced_points , axis = 0 )
179+ sort_points = jnp .take_along_axis (traced_points , sdx , axis = 0 )
180+ t = jnp .stack ([mesh_weight_map , mesh_weight_map ], axis = 1 )
181+ t = jnp .take_along_axis (t , sdx , axis = 0 )
182+ t = jnp .cumsum (t , axis = 0 )
183+
184+ # The CDF estimation needs to be a smooth function to avoid noise caused by
185+ # using a sub-set of traced points
186+ #
187+ # A polynomial is fit to the *inverse* CDF, this polynomial is inverted
188+ # numerically to get the smooth CDF function.
189+ #
190+ # The polynomial is fit to 'y' points at the Chebyshev nodes to avoid the
191+ # Runge phenomenon and to estimate the gradient of the CDF
192+ #
193+ # The gradient of the CDF is use as the weights for the polynomial fit
194+ # (e.g. where the CDF changes rapidly the weight is higher). This
195+ # helps prevent overfitting for high degree polynomials and helps keep
196+ # log degree polynomials monotonic.
197+
198+ # Use 3x more Chebyshev nodes than the degree being fit
199+ cheb_deg = 3 * deg
200+ # calculate nodes and interpolated values at the nodes
201+ cheb_nodes = ((jnp .cos ((2 * jnp .arange (cheb_deg ) + 1 ) * jnp .pi / (2 * cheb_deg ))[::- 1 ]) + 1 ) / 2
202+ cy = jnp .stack ([cheb_nodes , cheb_nodes ], axis = 1 )
203+ cx = jax .vmap (interp1d , in_axes = (None , 1 , 1 ), out_axes = 1 )(cheb_nodes , t , sort_points )
204+
205+ # fit the polynomial with weights
206+ w = v_gradient (cy , cx )
207+ coefs = v_polyfit (cy , cx , deg , None , False , w )
208+
209+ # invert the polynomial with custom class
210+ inv_poly = InvertPolySpline (coefs , sort_points [0 ], sort_points [- 1 ], low_res = 20 * deg )
211+ # return sort_points and t for plotting below, in production we just need the transforms
212+ return inv_poly , sort_points , t
213+
214+ ips , sort_points , t = create_transforms_spline (source_grid_scaled , deg = deg , mesh_weight_map = mesh_weight_map )
215+
216+ transform = jax .jit (ips .fwd_transform )
217+ inv_transform = jax .jit (ips .rev_transform )
218+
219+ return transform , inv_transform
0 commit comments