Skip to content

Commit 7e2468c

Browse files
author
Niek Wielders
committed
Make zeta_from JAX-compatible; compute zeta_from and wofz in higher (float64/complex128) precision; remove the scipy wofz import
1 parent dd480c3 commit 7e2468c

File tree

1 file changed

+41
-62
lines changed

1 file changed

+41
-62
lines changed

autogalaxy/profiles/mass/stellar/gaussian.py

Lines changed: 41 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -189,45 +189,31 @@ def axis_ratio(self, xp=np):
189189
return xp.where(axis_ratio < 0.9999, axis_ratio, 0.9999)
190190

191191
def zeta_from(self, grid: aa.type.Grid2DLike, xp=np):
    """
    Evaluate the complex ``zeta`` term of the Gaussian profile on a grid.

    For every (y, x) coordinate the grid is rescaled by
    ``q / (sigma * sqrt(2 * (1 - q**2)))`` and the quantity

        -1j * (w(z1) - exp_term * w(z2))

    is computed, where ``w`` is ``self.wofz``, ``z1 = xs + 1j * ys`` and
    ``z2 = q * xs + 1j * ys / q``.

    Rows with ``y < 0`` are evaluated at the mirrored point ``(x, -y)`` and
    the result is complex-conjugated. Conjugating alone is not sufficient:
    ``w(conj(z)) != conj(w(z))`` in general, so the ordinate must be
    mirrored *before* ``z1`` / ``z2`` are built.

    The whole computation is branch-free array arithmetic (no boolean-mask
    assignment), so it can be used with ``xp=jax.numpy`` as well as NumPy.

    Parameters
    ----------
    grid
        Object exposing ``grid.array`` with column 0 holding the y
        coordinates and column 1 holding the x coordinates.
    xp
        Array namespace to compute with (``numpy`` by default).

    Returns
    -------
    A 1D complex array with one entry per row of ``grid.array``.
    """
    q = xp.asarray(self.axis_ratio(xp), dtype=xp.float64)
    q2 = q * q

    y = xp.asarray(grid.array[:, 0], dtype=xp.float64)
    x = xp.asarray(grid.array[:, 1], dtype=xp.float64)

    ind_pos_y = y >= 0

    scale = q / (
        xp.asarray(self.sigma, dtype=xp.float64)
        * xp.sqrt(xp.asarray(2.0, dtype=xp.float64) * (1.0 - q2))
    )

    xs = x * scale
    # Mirror negative ordinates into the upper half-plane; the matching
    # conjugation is applied to the result below. `where` (rather than
    # `abs`) keeps a well-defined derivative of +1 at y == 0.
    ys = xp.where(ind_pos_y, y, -y) * scale

    z1 = xs + 1j * ys
    z2 = q * xs + 1j * ys / q

    exp_term = xp.exp(
        -(xs * xs) * (1.0 - q2)
        - (ys * ys) * (1.0 / q2 - 1.0)
    )

    core = -1j * (self.wofz(z1, xp=xp) - exp_term * self.wofz(z2, xp=xp))

    # Undo the mirroring for the rows that originally had y < 0.
    return xp.where(ind_pos_y, core, xp.conj(core))
231217

232218
def wofz(self, z, xp=np):
233219
"""
@@ -236,71 +222,64 @@ def wofz(self, z, xp=np):
236222
Valid for all complex z. JIT + autodiff safe.
237223
"""
238224

239-
z = xp.asarray(z)
225+
z = xp.asarray(z, dtype=xp.complex128)
240226
x = xp.real(z)
241227
y = xp.imag(z)
242228

243229
r2 = x * x + y * y
244230
y2 = y * y
245231
z2 = z * z
246-
sqrt_pi = xp.sqrt(xp.pi)
247232

248-
# --- Regions 1 to 4 ---
249-
r1_s1 = xp.array([2.5, 2.0, 1.5, 1.0, 0.5])
233+
sqrt_pi = xp.asarray(xp.sqrt(xp.pi), dtype=xp.float64)
234+
inv_sqrt_pi = xp.asarray(1.0 / sqrt_pi, dtype=xp.float64)
235+
236+
# ---------- Large-|z| continued fraction ----------
237+
r1_s1 = xp.asarray([2.5, 2.0, 1.5, 1.0, 0.5], dtype=xp.float64)
250238

251239
t = z
252-
for coef in r1_s1:
253-
t = z - coef / t
240+
for c in r1_s1:
241+
t = z - c / t
254242

255-
w_large = 1j / (t * sqrt_pi)
243+
w_large = 1j * inv_sqrt_pi / t
256244

257-
# --- Region 5: special small-imaginary case ---
258-
U5 = xp.array([1.320522, 35.7668, 219.031, 1540.787, 3321.990, 36183.31])
259-
V5 = xp.array(
260-
[1.841439, 61.57037, 364.2191, 2186.181, 9022.228, 24322.84, 32066.6]
261-
)
245+
# ---------- Region 5 ----------
246+
U5 = xp.asarray([1.320522, 35.7668, 219.031,
247+
1540.787, 3321.990, 36183.31], dtype=xp.float64)
248+
V5 = xp.asarray([1.841439, 61.57037, 364.2191,
249+
2186.181, 9022.228, 24322.84, 32066.6], dtype=xp.float64)
262250

263-
t = 1 / sqrt_pi
251+
t = inv_sqrt_pi
264252
for u in U5:
265253
t = u + z2 * t
266254

267-
s = 1.0
255+
s = xp.asarray(1.0, dtype=xp.float64)
268256
for v in V5:
269257
s = v + z2 * s
270258

271259
w5 = xp.exp(-z2) + 1j * z * t / s
272260

273-
# --- Region 6: remaining small-|z| region ---
274-
U6 = xp.array([5.9126262, 30.180142, 93.15558, 181.92853, 214.38239, 122.60793])
275-
V6 = xp.array(
276-
[
277-
10.479857,
278-
53.992907,
279-
170.35400,
280-
348.70392,
281-
457.33448,
282-
352.73063,
283-
122.60793,
284-
]
285-
)
261+
# ---------- Region 6 ----------
262+
U6 = xp.asarray([5.9126262, 30.180142, 93.15558,
263+
181.92853, 214.38239, 122.60793], dtype=xp.float64)
264+
V6 = xp.asarray([10.479857, 53.992907, 170.35400,
265+
348.70392, 457.33448, 352.73063, 122.60793], dtype=xp.float64)
286266

287-
t = 1 / sqrt_pi
267+
t = inv_sqrt_pi
288268
for u in U6:
289269
t = u - 1j * z * t
290270

291-
s = 1.0
271+
s = xp.asarray(1.0, dtype=xp.float64)
292272
for v in V6:
293273
s = v - 1j * z * s
294274

295275
w6 = t / s
296276

297-
# --- Regions ---
277+
# ---------- Region logic ----------
298278
reg1 = (r2 >= 62.0) | ((r2 >= 30.0) & (r2 < 62.0) & (y2 >= 1e-13))
299279
reg2 = ((r2 >= 30) & (r2 < 62) & (y2 < 1e-13)) | (
300280
(r2 >= 2.5) & (r2 < 30) & (y2 < 0.072)
301281
)
302282

303-
# --- Combine regions using pure array logic ---
304283
w = w6
305284
w = xp.where(reg2, w5, w)
306285
w = xp.where(reg1, w_large, w)

0 commit comments

Comments
 (0)