Skip to content

Commit bff670e

Browse files
committed
shear patch
1 parent 8b2e09c commit bff670e

File tree

1 file changed

+43
-16
lines changed

1 file changed

+43
-16
lines changed

aeolis/shear.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -576,31 +576,58 @@ def compute_shear(self, u0, nfilter=(1., 2.)):
576576

577577
# Arrays in Fourier
578578
k = np.sqrt(kx**2 + ky**2)
579-
sigma = np.sqrt(1j * L * kx * z0new /l)
580579

581580

582581
time_start_perturbation = time.time()
583-
582+
583+
584584
# Shear stress perturbation
585-
# Use safe division to avoid zero/invalid values at kx=0 or k=0
586-
k_safe = np.where(k == 0, 1.0, k)
587-
kx_safe = np.where(kx == 0, 1.0, kx)
588-
589-
dtaux_t = hs * kx**2 / k_safe * 2 / ul**2 * \
590-
(-1. + (2. * np.log(l/z0new) + k**2/kx_safe**2) * sigma * \
591-
sc_kv(1., 2. * sigma) / sc_kv(0., 2. * sigma))
585+
# Use masked computation to avoid division by zero and invalid special-function calls.
586+
# Build boolean mask for valid Fourier modes where formula is defined.
587+
valid = (k != 0) & (kx != 0)
588+
589+
# Pre-allocate zero arrays for Fourier-domain shear perturbations
590+
dtaux_t = np.zeros_like(hs, dtype=complex)
591+
dtauy_t = np.zeros_like(hs, dtype=complex)
592+
593+
if np.any(valid):
594+
# Extract valid-mode arrays
595+
k_v = k[valid]
596+
kx_v = kx[valid]
597+
ky_v = ky[valid]
598+
hs_v = hs[valid]
599+
600+
# z0new can be scalar or array; index accordingly
601+
if np.size(z0new) == 1:
602+
z0_v = z0new
603+
else:
604+
z0_v = z0new[valid]
592605

593-
dtauy_t = hs * kx * ky / k_safe * 2 / ul**2 * \
594-
2. * np.sqrt(2.) * sigma * sc_kv(1., 2. * np.sqrt(2.) * sigma)
606+
# compute sigma on valid modes
607+
sigma_v = np.sqrt(1j * L * kx_v * z0_v / l)
595608

596-
# Zero out invalid regions (kx=0 or k=0) where formulation is not valid
597-
invalid_mask = (k == 0) | (kx == 0)
598-
dtaux_t[invalid_mask] = 0.
599-
dtauy_t[invalid_mask] = 0.
600-
609+
# Evaluate Bessel K functions on valid arguments only
610+
kv0 = sc_kv(0., 2. * sigma_v)
611+
kv1 = sc_kv(1., 2. * sigma_v)
612+
613+
# main x-direction perturbation (vectorized on valid indices)
614+
term_x = -1. + (2. * np.log(l / z0_v) + (k_v**2) / (kx_v**2)) * sigma_v * (kv1 / kv0)
615+
dtaux_v = hs_v * (kx_v**2) / k_v * 2. / ul**2 * term_x
616+
617+
# y-direction perturbation (also vectorized)
618+
kv1_y = sc_kv(1., 2. * np.sqrt(2.) * sigma_v)
619+
dtauy_v = hs_v * (kx_v * ky_v) / k_v * 2. / ul**2 * 2. * np.sqrt(2.) * sigma_v * (kv1_y)
620+
621+
# store back into full arrays (other entries remain zero)
622+
dtaux_t[valid] = dtaux_v
623+
dtauy_t[valid] = dtauy_v
624+
625+
# invalid modes remain 0 (physically reasonable for k=0 or kx=0)
601626
gc['dtaux'] = np.real(np.fft.ifft2(dtaux_t))
602627
gc['dtauy'] = np.real(np.fft.ifft2(dtauy_t))
603628

629+
630+
604631

605632
def separation_shear(self, hsep):
606633
'''Reduces the computed wind shear perturbation below the

0 commit comments

Comments
 (0)