@@ -189,45 +189,31 @@ def axis_ratio(self, xp=np):
189189 return xp .where (axis_ratio < 0.9999 , axis_ratio , 0.9999 )
190190
191191 def zeta_from (self , grid : aa .type .Grid2DLike , xp = np ):
192- q = self .axis_ratio (xp )
193- q2 = q ** 2.0
192+ q = xp . asarray ( self .axis_ratio (xp ), dtype = xp . float64 )
193+ q2 = q * q
194194
195- ind_pos_y = grid .array [:, 0 ] >= 0
196- shape_grid = np .shape (grid )
197- output_grid = np .zeros ((shape_grid [0 ]), dtype = np .complex128 )
195+ y = xp .asarray (grid .array [:, 0 ], dtype = xp .float64 )
196+ x = xp .asarray (grid .array [:, 1 ], dtype = xp .float64 )
198197
199- scale_factor = q / ( self . sigma * xp . sqrt ( 2.0 * ( 1.0 - q2 )))
198+ ind_pos_y = y >= 0
200199
201- xs_0 = grid .array [:, 1 ][ind_pos_y ] * scale_factor
202- ys_0 = grid .array [:, 0 ][ind_pos_y ] * scale_factor
203- xs_1 = grid .array [:, 1 ][~ ind_pos_y ] * scale_factor
204- ys_1 = - grid .array [:, 0 ][~ ind_pos_y ] * scale_factor
200+ scale = q / (xp .asarray (self .sigma , dtype = xp .float64 )
201+ * xp .sqrt (xp .asarray (2.0 , dtype = xp .float64 ) * (1.0 - q2 )))
205202
206- z1_0 = xs_0 + 1j * ys_0
207- z2_0 = q * xs_0 + 1j * ys_0 / q
208- z1_1 = xs_1 + 1j * ys_1
209- z2_1 = q * xs_1 + 1j * ys_1 / q
203+ xs = x * scale
204+ ys = y * scale
210205
211- exp_term_0 = xp . exp ( - ( xs_0 ** 2 ) * ( 1.0 - q2 ) - ys_0 ** 2 * ( 1.0 / q2 - 1.0 ))
212- exp_term_1 = xp . exp ( - ( xs_1 ** 2 ) * ( 1.0 - q2 ) - ys_1 ** 2 * ( 1.0 / q2 - 1.0 ))
206+ z1 = xs + 1j * ys
207+ z2 = q * xs + 1j * ys / q
213208
214- if xp == np :
215- from scipy .special import wofz
216-
217- output_grid [ind_pos_y ] = - 1j * (wofz (z1_0 ) - exp_term_0 * wofz (z2_0 ))
218- output_grid [~ ind_pos_y ] = xp .conj (
219- - 1j * (wofz (z1_1 ) - exp_term_1 * wofz (z2_1 ))
220- )
209+ exp_term = xp .exp (
210+ - (xs * xs ) * (1.0 - q2 )
211+ - (ys * ys ) * (1.0 / q2 - 1.0 )
212+ )
221213
222- else :
223- output_grid [ind_pos_y ] = - 1j * (
224- self .wofz (z1_0 , xp = xp ) - exp_term_0 * self .wofz (z2_0 , xp = xp )
225- )
226- output_grid [~ ind_pos_y ] = xp .conj (
227- - 1j * (self .wofz (z1_1 , xp = xp ) - exp_term_1 * self .wofz (z2_1 , xp = xp ))
228- )
214+ core = - 1j * (self .wofz (z1 , xp = xp ) - exp_term * self .wofz (z2 , xp = xp ))
229215
230- return output_grid
216+ return xp . where ( ind_pos_y , core , xp . conj ( core ))
231217
232218 def wofz (self , z , xp = np ):
233219 """
@@ -236,71 +222,64 @@ def wofz(self, z, xp=np):
236222 Valid for all complex z. JIT + autodiff safe.
237223 """
238224
239- z = xp .asarray (z )
225+ z = xp .asarray (z , dtype = xp . complex128 )
240226 x = xp .real (z )
241227 y = xp .imag (z )
242228
243229 r2 = x * x + y * y
244230 y2 = y * y
245231 z2 = z * z
246- sqrt_pi = xp .sqrt (xp .pi )
247232
248- # --- Regions 1 to 4 ---
249- r1_s1 = xp .array ([2.5 , 2.0 , 1.5 , 1.0 , 0.5 ])
233+ sqrt_pi = xp .asarray (xp .sqrt (xp .pi ), dtype = xp .float64 )
234+ inv_sqrt_pi = xp .asarray (1.0 / sqrt_pi , dtype = xp .float64 )
235+
236+ # ---------- Large-|z| continued fraction ----------
237+ r1_s1 = xp .asarray ([2.5 , 2.0 , 1.5 , 1.0 , 0.5 ], dtype = xp .float64 )
250238
251239 t = z
252- for coef in r1_s1 :
253- t = z - coef / t
240+ for c in r1_s1 :
241+ t = z - c / t
254242
255- w_large = 1j / ( t * sqrt_pi )
243+ w_large = 1j * inv_sqrt_pi / t
256244
257- # --- Region 5: special small-imaginary case ---
258- U5 = xp .array ([1.320522 , 35.7668 , 219.031 , 1540.787 , 3321.990 , 36183.31 ])
259- V5 = xp .array (
260- [1.841439 , 61.57037 , 364.2191 , 2186.181 , 9022.228 , 24322.84 , 32066.6 ]
261- )
245+ # ---------- Region 5 ------- ---
246+ U5 = xp .asarray ([1.320522 , 35.7668 , 219.031 ,
247+ 1540.787 , 3321.990 , 36183.31 ], dtype = xp .float64 )
248+ V5 = xp . asarray ( [1.841439 , 61.57037 , 364.2191 ,
249+ 2186.181 , 9022.228 , 24322.84 , 32066.6 ], dtype = xp . float64 )
262250
263- t = 1 / sqrt_pi
251+ t = inv_sqrt_pi
264252 for u in U5 :
265253 t = u + z2 * t
266254
267- s = 1.0
255+ s = xp . asarray ( 1.0 , dtype = xp . float64 )
268256 for v in V5 :
269257 s = v + z2 * s
270258
271259 w5 = xp .exp (- z2 ) + 1j * z * t / s
272260
273- # --- Region 6: remaining small-|z| region ---
274- U6 = xp .array ([5.9126262 , 30.180142 , 93.15558 , 181.92853 , 214.38239 , 122.60793 ])
275- V6 = xp .array (
276- [
277- 10.479857 ,
278- 53.992907 ,
279- 170.35400 ,
280- 348.70392 ,
281- 457.33448 ,
282- 352.73063 ,
283- 122.60793 ,
284- ]
285- )
261+ # ---------- Region 6 ----------
262+ U6 = xp .asarray ([5.9126262 , 30.180142 , 93.15558 ,
263+ 181.92853 , 214.38239 , 122.60793 ], dtype = xp .float64 )
264+ V6 = xp .asarray ([10.479857 , 53.992907 , 170.35400 ,
265+ 348.70392 , 457.33448 , 352.73063 , 122.60793 ], dtype = xp .float64 )
286266
287- t = 1 / sqrt_pi
267+ t = inv_sqrt_pi
288268 for u in U6 :
289269 t = u - 1j * z * t
290270
291- s = 1.0
271+ s = xp . asarray ( 1.0 , dtype = xp . float64 )
292272 for v in V6 :
293273 s = v - 1j * z * s
294274
295275 w6 = t / s
296276
297- # --- Regions ---
277+ # ---------- Region logic ------- ---
298278 reg1 = (r2 >= 62.0 ) | ((r2 >= 30.0 ) & (r2 < 62.0 ) & (y2 >= 1e-13 ))
299279 reg2 = ((r2 >= 30 ) & (r2 < 62 ) & (y2 < 1e-13 )) | (
300280 (r2 >= 2.5 ) & (r2 < 30 ) & (y2 < 0.072 )
301281 )
302282
303- # --- Combine regions using pure array logic ---
304283 w = w6
305284 w = xp .where (reg2 , w5 , w )
306285 w = xp .where (reg1 , w_large , w )
0 commit comments