Skip to content

Commit 39934d5

Browse files
Merge pull request #272 from openearth/main
Fix ustarn calculation: initialization and FFT shear formula bugs (#265)
2 parents b3fb74d + 506064d commit 39934d5

File tree

2 files changed

+54
-18
lines changed

2 files changed

+54
-18
lines changed

aeolis/shear.py

Lines changed: 52 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -389,8 +389,8 @@ def compute_shear(self, u0, nfilter=(1., 2.)):
389389
return
390390

391391
ny, nx = gc['z'].shape
392-
kx, ky = np.meshgrid(2. * np.pi * np.fft.fftfreq(nx+1, gc['dx'])[1:],
393-
2. * np.pi * np.fft.fftfreq(ny+1, gc['dy'])[1:])
392+
kx, ky = np.meshgrid(2. * np.pi * np.fft.fftfreq(nx, gc['dx']),
393+
2. * np.pi * np.fft.fftfreq(ny, gc['dy']))
394394

395395
hs = np.fft.fft2(gc['z'])
396396
hs = self.filter_highfrequenies(kx, ky, hs, nfilter)
@@ -421,25 +421,58 @@ def compute_shear(self, u0, nfilter=(1., 2.)):
421421

422422
# Arrays in Fourier
423423
k = np.sqrt(kx**2 + ky**2)
424-
sigma = np.sqrt(1j * L * kx * z0new /l)
425424

426425

427426
time_start_perturbation = time.time()
428-
429-
# Shear stress perturbation
430-
431-
dtaux_t = hs * kx**2 / k * 2 / ul**2 * \
432-
(-1. + (2. * np.log(l/z0new) + k**2/kx**2) * sigma * \
433-
sc_kv(1., 2. * sigma) / sc_kv(0., 2. * sigma))
434427

435-
436-
dtauy_t = hs * kx * ky / k * 2 / ul**2 * \
437-
2. * np.sqrt(2.) * sigma * sc_kv(1., 2. * np.sqrt(2.) * sigma)
438428

439-
429+
# Shear stress perturbation
430+
# Use masked computation to avoid division by zero and invalid special-function calls.
431+
# Build boolean mask for valid Fourier modes where formula is defined.
432+
valid = (k != 0) & (kx != 0)
433+
434+
# Pre-allocate zero arrays for Fourier-domain shear perturbations
435+
dtaux_t = np.zeros_like(hs, dtype=complex)
436+
dtauy_t = np.zeros_like(hs, dtype=complex)
437+
438+
if np.any(valid):
439+
# Extract valid-mode arrays
440+
k_v = k[valid]
441+
kx_v = kx[valid]
442+
ky_v = ky[valid]
443+
hs_v = hs[valid]
444+
445+
# z0new can be scalar or array; index accordingly
446+
if np.size(z0new) == 1:
447+
z0_v = z0new
448+
else:
449+
z0_v = z0new[valid]
450+
451+
# compute sigma on valid modes
452+
sigma_v = np.sqrt(1j * L * kx_v * z0_v / l)
453+
454+
# Evaluate Bessel K functions on valid arguments only
455+
kv0 = sc_kv(0., 2. * sigma_v)
456+
kv1 = sc_kv(1., 2. * sigma_v)
457+
458+
# main x-direction perturbation (vectorized on valid indices)
459+
term_x = -1. + (2. * np.log(l / z0_v) + (k_v**2) / (kx_v**2)) * sigma_v * (kv1 / kv0)
460+
dtaux_v = hs_v * (kx_v**2) / k_v * 2. / ul**2 * term_x
461+
462+
# y-direction perturbation (also vectorized)
463+
kv1_y = sc_kv(1., 2. * np.sqrt(2.) * sigma_v)
464+
dtauy_v = hs_v * (kx_v * ky_v) / k_v * 2. / ul**2 * 2. * np.sqrt(2.) * sigma_v * (kv1_y)
465+
466+
# store back into full arrays (other entries remain zero)
467+
dtaux_t[valid] = dtaux_v
468+
dtauy_t[valid] = dtauy_v
469+
470+
# invalid modes remain 0 (physically reasonable for k=0 or kx=0)
440471
gc['dtaux'] = np.real(np.fft.ifft2(dtaux_t))
441472
gc['dtauy'] = np.real(np.fft.ifft2(dtauy_t))
442473

474+
475+
443476

444477
def separation_shear(self, hsep):
445478
'''Reduces the computed wind shear perturbation below the
@@ -513,8 +546,11 @@ def filter_highfrequenies(self, kx, ky, hs, nfilter=(1, 2)):
513546
if nfilter is not None:
514547
n1 = np.min(nfilter)
515548
n2 = np.max(nfilter)
516-
px = 2 * np.pi / self.cgrid['dx'] / np.abs(kx)
517-
py = 2 * np.pi / self.cgrid['dy'] / np.abs(ky)
549+
# Avoid division by zero at DC component (kx=0, ky=0)
550+
kx_safe = np.where(kx == 0, 1.0, kx)
551+
ky_safe = np.where(ky == 0, 1.0, ky)
552+
px = 2 * np.pi / self.cgrid['dx'] / np.abs(kx_safe)
553+
py = 2 * np.pi / self.cgrid['dy'] / np.abs(ky_safe)
518554
s1 = n1 / np.log(1. / .01 - 1.)
519555
s2 = -n2 / np.log(1. / .99 - 1.)
520556
f1 = 1. / (1. + np.exp(-(px + n1 - n2) / s1))
@@ -727,4 +763,4 @@ def interpolate(self, x, y, z, xi, yi, z0):
727763

728764

729765

730-
766+

aeolis/wind.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,8 @@ def interpolate(s, p, t):
153153
s = velocity_stress(s,p)
154154

155155
s['ustar0'] = s['ustar'].copy()
156-
s['ustars0'] = s['ustar'].copy()
157-
s['ustarn0'] = s['ustar'].copy()
156+
s['ustars0'] = s['ustars'].copy()
157+
s['ustarn0'] = s['ustarn'].copy()
158158

159159
s['tau0'] = s['tau'].copy()
160160
s['taus0'] = s['taus'].copy()

0 commit comments

Comments
 (0)