@@ -576,31 +576,58 @@ def compute_shear(self, u0, nfilter=(1., 2.)):
576576
577577 # Arrays in Fourier
578578 k = np .sqrt (kx ** 2 + ky ** 2 )
579- sigma = np .sqrt (1j * L * kx * z0new / l )
580579
581580
582581 time_start_perturbation = time .time ()
583-
582+
583+
584584 # Shear stress perturbation
585- # Use safe division to avoid zero/invalid values at kx=0 or k=0
586- k_safe = np .where (k == 0 , 1.0 , k )
587- kx_safe = np .where (kx == 0 , 1.0 , kx )
588-
589- dtaux_t = hs * kx ** 2 / k_safe * 2 / ul ** 2 * \
590- (- 1. + (2. * np .log (l / z0new ) + k ** 2 / kx_safe ** 2 ) * sigma * \
591- sc_kv (1. , 2. * sigma ) / sc_kv (0. , 2. * sigma ))
585+ # Use masked computation to avoid division by zero and invalid special-function calls.
586+ # Build boolean mask for valid Fourier modes where formula is defined.
587+ valid = (k != 0 ) & (kx != 0 )
588+
589+ # Pre-allocate zero arrays for Fourier-domain shear perturbations
590+ dtaux_t = np .zeros_like (hs , dtype = complex )
591+ dtauy_t = np .zeros_like (hs , dtype = complex )
592+
593+ if np .any (valid ):
594+ # Extract valid-mode arrays
595+ k_v = k [valid ]
596+ kx_v = kx [valid ]
597+ ky_v = ky [valid ]
598+ hs_v = hs [valid ]
599+
600+ # z0new can be scalar or array; index accordingly
601+ if np .size (z0new ) == 1 :
602+ z0_v = z0new
603+ else :
604+ z0_v = z0new [valid ]
592605
593- dtauy_t = hs * kx * ky / k_safe * 2 / ul ** 2 * \
594- 2. * np .sqrt (2. ) * sigma * sc_kv ( 1. , 2. * np . sqrt ( 2. ) * sigma )
606+ # compute sigma on valid modes
607+ sigma_v = np .sqrt (1j * L * kx_v * z0_v / l )
595608
596- # Zero out invalid regions (kx=0 or k=0) where formulation is not valid
597- invalid_mask = (k == 0 ) | (kx == 0 )
598- dtaux_t [invalid_mask ] = 0.
599- dtauy_t [invalid_mask ] = 0.
600-
609+ # Evaluate Bessel K functions on valid arguments only
610+ kv0 = sc_kv (0. , 2. * sigma_v )
611+ kv1 = sc_kv (1. , 2. * sigma_v )
612+
613+ # main x-direction perturbation (vectorized on valid indices)
614+ term_x = - 1. + (2. * np .log (l / z0_v ) + (k_v ** 2 ) / (kx_v ** 2 )) * sigma_v * (kv1 / kv0 )
615+ dtaux_v = hs_v * (kx_v ** 2 ) / k_v * 2. / ul ** 2 * term_x
616+
617+ # y-direction perturbation (also vectorized)
618+ kv1_y = sc_kv (1. , 2. * np .sqrt (2. ) * sigma_v )
619+ dtauy_v = hs_v * (kx_v * ky_v ) / k_v * 2. / ul ** 2 * 2. * np .sqrt (2. ) * sigma_v * (kv1_y )
620+
621+ # store back into full arrays (other entries remain zero)
622+ dtaux_t [valid ] = dtaux_v
623+ dtauy_t [valid ] = dtauy_v
624+
625+ # invalid modes remain 0 (physically reasonable for k=0 or kx=0)
601626 gc ['dtaux' ] = np .real (np .fft .ifft2 (dtaux_t ))
602627 gc ['dtauy' ] = np .real (np .fft .ifft2 (dtauy_t ))
603628
629+
630+
604631
605632 def separation_shear (self , hsep ):
606633 '''Reduces the computed wind shear perturbation below the
0 commit comments