Skip to content

Commit faa36e7

Browse files
CopilotSierd
andcommitted
Handle division by zero at DC component in FFT shear
The previous fix included the DC component (kx=0, ky=0) in the frequency arrays, which caused division by zero errors in the shear calculations. Fixed by: - Using safe division with np.where to replace zeros with 1.0 temporarily - Explicitly setting DC component of perturbations to 0 after calculation - Applying same fix to filter_highfrequencies function The DC component represents the mean value and doesn't contribute to perturbations, so setting it to zero is physically correct. Co-authored-by: Sierd <14054272+Sierd@users.noreply.github.com>
1 parent 31bd01c commit faa36e7

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

aeolis/shear.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -582,15 +582,22 @@ def compute_shear(self, u0, nfilter=(1., 2.)):
582582
time_start_perturbation = time.time()
583583

584584
# Shear stress perturbation
585+
# Avoid division by zero at DC component (kx=0, ky=0)
586+
# Set a small value to avoid division by zero, then set DC to 0 after calculation
587+
k_safe = np.where(k == 0, 1.0, k)
588+
kx_safe = np.where(kx == 0, 1.0, kx)
585589

586-
dtaux_t = hs * kx**2 / k * 2 / ul**2 * \
587-
(-1. + (2. * np.log(l/z0new) + k**2/kx**2) * sigma * \
590+
dtaux_t = hs * kx**2 / k_safe * 2 / ul**2 * \
591+
(-1. + (2. * np.log(l/z0new) + k**2/kx_safe**2) * sigma * \
588592
sc_kv(1., 2. * sigma) / sc_kv(0., 2. * sigma))
589593

590594

591-
dtauy_t = hs * kx * ky / k * 2 / ul**2 * \
595+
dtauy_t = hs * kx * ky / k_safe * 2 / ul**2 * \
592596
2. * np.sqrt(2.) * sigma * sc_kv(1., 2. * np.sqrt(2.) * sigma) / sc_kv(0., 2. * np.sqrt(2.) * sigma)
593597

598+
# Set DC component to zero (no perturbation at zero frequency)
599+
dtaux_t[k == 0] = 0.
600+
dtauy_t[k == 0] = 0.
594601

595602
gc['dtaux'] = np.real(np.fft.ifft2(dtaux_t))
596603
gc['dtauy'] = np.real(np.fft.ifft2(dtauy_t))
@@ -668,8 +675,11 @@ def filter_highfrequenies(self, kx, ky, hs, nfilter=(1, 2)):
668675
if nfilter is not None:
669676
n1 = np.min(nfilter)
670677
n2 = np.max(nfilter)
671-
px = 2 * np.pi / self.cgrid['dx'] / np.abs(kx)
672-
py = 2 * np.pi / self.cgrid['dy'] / np.abs(ky)
678+
# Avoid division by zero at DC component (kx=0, ky=0)
679+
kx_safe = np.where(kx == 0, 1.0, kx)
680+
ky_safe = np.where(ky == 0, 1.0, ky)
681+
px = 2 * np.pi / self.cgrid['dx'] / np.abs(kx_safe)
682+
py = 2 * np.pi / self.cgrid['dy'] / np.abs(ky_safe)
673683
s1 = n1 / np.log(1. / .01 - 1.)
674684
s2 = -n2 / np.log(1. / .99 - 1.)
675685
f1 = 1. / (1. + np.exp(-(px + n1 - n2) / s1))

0 commit comments

Comments
 (0)