@@ -389,8 +389,8 @@ def compute_shear(self, u0, nfilter=(1., 2.)):
389389 return
390390
391391 ny , nx = gc ['z' ].shape
392- kx , ky = np .meshgrid (2. * np .pi * np .fft .fftfreq (nx + 1 , gc ['dx' ])[ 1 :] ,
393- 2. * np .pi * np .fft .fftfreq (ny + 1 , gc ['dy' ])[ 1 :] )
392+ kx , ky = np .meshgrid (2. * np .pi * np .fft .fftfreq (nx , gc ['dx' ]),
393+ 2. * np .pi * np .fft .fftfreq (ny , gc ['dy' ]))
394394
395395 hs = np .fft .fft2 (gc ['z' ])
396396 hs = self .filter_highfrequenies (kx , ky , hs , nfilter )
@@ -421,25 +421,58 @@ def compute_shear(self, u0, nfilter=(1., 2.)):
421421
422422 # Arrays in Fourier
423423 k = np .sqrt (kx ** 2 + ky ** 2 )
424- sigma = np .sqrt (1j * L * kx * z0new / l )
425424
426425
427426 time_start_perturbation = time .time ()
428-
429- # Shear stress perturbation
430-
431- dtaux_t = hs * kx ** 2 / k * 2 / ul ** 2 * \
432- (- 1. + (2. * np .log (l / z0new ) + k ** 2 / kx ** 2 ) * sigma * \
433- sc_kv (1. , 2. * sigma ) / sc_kv (0. , 2. * sigma ))
434427
435-
436- dtauy_t = hs * kx * ky / k * 2 / ul ** 2 * \
437- 2. * np .sqrt (2. ) * sigma * sc_kv (1. , 2. * np .sqrt (2. ) * sigma )
438428
439-
429+ # Shear stress perturbation
430+ # Use masked computation to avoid division by zero and invalid special-function calls.
431+ # Build boolean mask for valid Fourier modes where formula is defined.
432+ valid = (k != 0 ) & (kx != 0 )
433+
434+ # Pre-allocate zero arrays for Fourier-domain shear perturbations
435+ dtaux_t = np .zeros_like (hs , dtype = complex )
436+ dtauy_t = np .zeros_like (hs , dtype = complex )
437+
438+ if np .any (valid ):
439+ # Extract valid-mode arrays
440+ k_v = k [valid ]
441+ kx_v = kx [valid ]
442+ ky_v = ky [valid ]
443+ hs_v = hs [valid ]
444+
445+ # z0new can be scalar or array; index accordingly
446+ if np .size (z0new ) == 1 :
447+ z0_v = z0new
448+ else :
449+ z0_v = z0new [valid ]
450+
451+ # compute sigma on valid modes
452+ sigma_v = np .sqrt (1j * L * kx_v * z0_v / l )
453+
454+ # Evaluate Bessel K functions on valid arguments only
455+ kv0 = sc_kv (0. , 2. * sigma_v )
456+ kv1 = sc_kv (1. , 2. * sigma_v )
457+
458+ # main x-direction perturbation (vectorized on valid indices)
459+ term_x = - 1. + (2. * np .log (l / z0_v ) + (k_v ** 2 ) / (kx_v ** 2 )) * sigma_v * (kv1 / kv0 )
460+ dtaux_v = hs_v * (kx_v ** 2 ) / k_v * 2. / ul ** 2 * term_x
461+
462+ # y-direction perturbation (also vectorized)
463+ kv1_y = sc_kv (1. , 2. * np .sqrt (2. ) * sigma_v )
464+ dtauy_v = hs_v * (kx_v * ky_v ) / k_v * 2. / ul ** 2 * 2. * np .sqrt (2. ) * sigma_v * (kv1_y )
465+
466+ # store back into full arrays (other entries remain zero)
467+ dtaux_t [valid ] = dtaux_v
468+ dtauy_t [valid ] = dtauy_v
469+
470+ # invalid modes remain 0 (physically reasonable for k=0 or kx=0)
440471 gc ['dtaux' ] = np .real (np .fft .ifft2 (dtaux_t ))
441472 gc ['dtauy' ] = np .real (np .fft .ifft2 (dtauy_t ))
442473
474+
475+
443476
444477 def separation_shear (self , hsep ):
445478 '''Reduces the computed wind shear perturbation below the
@@ -513,8 +546,11 @@ def filter_highfrequenies(self, kx, ky, hs, nfilter=(1, 2)):
513546 if nfilter is not None :
514547 n1 = np .min (nfilter )
515548 n2 = np .max (nfilter )
516- px = 2 * np .pi / self .cgrid ['dx' ] / np .abs (kx )
517- py = 2 * np .pi / self .cgrid ['dy' ] / np .abs (ky )
549+ # Avoid division by zero at DC component (kx=0, ky=0)
550+ kx_safe = np .where (kx == 0 , 1.0 , kx )
551+ ky_safe = np .where (ky == 0 , 1.0 , ky )
552+ px = 2 * np .pi / self .cgrid ['dx' ] / np .abs (kx_safe )
553+ py = 2 * np .pi / self .cgrid ['dy' ] / np .abs (ky_safe )
518554 s1 = n1 / np .log (1. / .01 - 1. )
519555 s2 = - n2 / np .log (1. / .99 - 1. )
520556 f1 = 1. / (1. + np .exp (- (px + n1 - n2 ) / s1 ))
@@ -727,4 +763,4 @@ def interpolate(self, x, y, z, xi, yi, z0):
727763
728764
729765
730-
766+
0 commit comments