11from __future__ import annotations
2+ import jax .numpy as jnp
3+ import jax .scipy .special as jsp
24import numpy as np
3-
45import math
5- import scipy .special as sc
66from typing import TYPE_CHECKING
77
88if TYPE_CHECKING :
1212
1313from autoarray import numba_util
1414
15+ import jax .numpy as jnp
16+
17+
18+ def kv_xp (v , z , xp = np ):
19+ """
20+ XP-compatible modified Bessel K_v(v, z).
21+
22+ NumPy backend:
23+ -> scipy.special.kv
1524
16- @numba_util .jit (cache = False )
17- def matern_kernel (r : float , l : float = 1.0 , v : float = 0.5 ):
25+ JAX backend:
26+ -> jax.scipy.special.kv if available
27+ -> else tfp.substrates.jax.math.bessel_kve * exp(-|z|)
1828 """
19- need to `pip install numba-scipy `
20- see https://gaussianprocess.org/gpml/chapters/RW4.pdf for more info
2129
22- the distance r need to be scalar
23- l is the scale
24- v is the order, better < 30, otherwise may have numerical NaN issue.
30+ # -------------------------
31+ # NumPy backend
32+ # -------------------------
33+ if xp is np :
34+ import scipy .special as sc
35+
36+ return sc .kv (v , z )
2537
26- v control the smoothness level. the larger the v, the stronger smoothing condition (i.e., the solution is
27- v-th differentiable) imposed by the kernel.
38+ # -------------------------
39+ # JAX backend
40+ # -------------------------
41+ else :
42+ try :
43+ import tensorflow_probability .substrates .jax as tfp
44+
45+ return tfp .math .bessel_kve (v , z ) * xp .exp (- xp .abs (z ))
46+ except ImportError :
47+ raise ImportError (
48+ "To use the JAX backend with the Matérn kernel, "
49+ "please install tensorflow-probability via `pip install tensorflow-probability==0.25.0`."
50+ )
51+
52+
53+ def matern_kernel (r , l : float = 1.0 , v : float = 0.5 , xp = np ):
2854 """
29- r = abs (r )
30- if r == 0 :
31- r = 0.00000001
32- part1 = 2 ** (1 - v ) / math .gamma (v )
33- part2 = (math .sqrt (2 * v ) * r / l ) ** v
34- part3 = sc .kv (v , math .sqrt (2 * v ) * r / l )
55+ XP-compatible Matérn kernel.
56+ Works with NumPy or JAX.
57+ """
58+
59+ # Avoid r = 0 singularity (JAX-safe)
60+ r = xp .maximum (xp .abs (r ), 1e-8 )
61+
62+ z = xp .sqrt (2.0 * v ) * r / l
63+
64+ part1 = 2.0 ** (1.0 - v ) / math .gamma (v ) # scalar constant
65+ part2 = z ** v
66+ part3 = kv_xp (v , z , xp )
67+
3568 return part1 * part2 * part3
3669
3770
38- @numba_util .jit (cache = False )
3971def matern_cov_matrix_from (
4072 scale : float ,
4173 nu : float ,
42- pixel_points : np .ndarray ,
43- ) -> np .ndarray :
74+ pixel_points ,
75+ xp = np ,
76+ ):
4477 """
4578 Consutruct the regularization covariance matrix, which is used to determined the regularization pattern (i.e,
4679 how the different pixels are correlated).
@@ -63,45 +96,27 @@ def matern_cov_matrix_from(
6396 The source covariance matrix (2d array), shape [N_source_pixels, N_source_pixels].
6497 """
6598
66- pixels = len (pixel_points )
67- covariance_matrix = np .zeros (shape = (pixels , pixels ))
68-
69- for i in range (pixels ):
70- covariance_matrix [i , i ] += 1e-8
71- for j in range (pixels ):
72- xi = pixel_points [i , 1 ]
73- yi = pixel_points [i , 0 ]
74- xj = pixel_points [j , 1 ]
75- yj = pixel_points [j , 0 ]
76- d_ij = np .sqrt (
77- (xi - xj ) ** 2 + (yi - yj ) ** 2
78- ) # distance between the pixel i and j
79-
80- covariance_matrix [i , j ] += matern_kernel (d_ij , l = scale , v = nu )
81-
82- return covariance_matrix
99+ # --------------------------------
100+ # Pairwise distances (broadcasted)
101+ # --------------------------------
102+ # pixel_points[:, None, :] -> (N, 1, 2)
103+ # pixel_points[None, :, :] -> (1, N, 2)
104+ diff = pixel_points [:, None , :] - pixel_points [None , :, :] # (N, N, 2)
83105
106+ d_ij = xp .sqrt (diff [..., 0 ] ** 2 + diff [..., 1 ] ** 2 ) # (N, N)
84107
85- class NumbaScipyPlaceholder :
86- pass
108+ # --------------------------------
109+ # Apply Matérn kernel elementwise
110+ # --------------------------------
111+ covariance_matrix = matern_kernel (d_ij , l = scale , v = nu , xp = xp )
87112
113+ # --------------------------------
114+ # Add diagonal jitter (JAX-safe)
115+ # --------------------------------
116+ pixels = pixel_points .shape [0 ]
117+ covariance_matrix = covariance_matrix + 1e-8 * xp .eye (pixels )
88118
89- try :
90- import numba_scipy
91-
92- numba_scipy = object
93- except ModuleNotFoundError :
94- numba_scipy = NumbaScipyPlaceholder ()
95-
96-
97- def numba_scipy_exception ():
98- raise ModuleNotFoundError (
99- "\n --------------------\n "
100- "You are attempting to use the MaternKernel for Regularization.\n \n "
101- "However, the optional library numba_scipy (https://pypi.org/project/numba-scipy/) is not installed.\n \n "
102- "Install it via the command `pip install numba-scipy==0.3.1`.\n \n "
103- "----------------------"
104- )
119+ return covariance_matrix
105120
106121
107122class MaternKernel (AbstractRegularization ):
@@ -131,9 +146,6 @@ def __init__(self, coefficient: float = 1.0, scale: float = 1.0, nu: float = 0.5
131146 Controls the derivative of the regularization pattern (`nu=0.5` is a Gaussian).
132147 """
133148
134- if isinstance (numba_scipy , NumbaScipyPlaceholder ):
135- numba_scipy_exception ()
136-
137149 self .coefficient = coefficient
138150 self .scale = float (scale )
139151 self .nu = float (nu )
@@ -175,8 +187,9 @@ def regularization_matrix_from(self, linear_obj: LinearObj, xp=np) -> np.ndarray
175187 """
176188 covariance_matrix = matern_cov_matrix_from (
177189 scale = self .scale ,
178- pixel_points = xp . array ( linear_obj .source_plane_mesh_grid ) ,
190+ pixel_points = linear_obj .source_plane_mesh_grid . array ,
179191 nu = self .nu ,
192+ xp = xp ,
180193 )
181194
182195 return self .coefficient * xp .linalg .inv (covariance_matrix )
0 commit comments