From b43af77a1eb9756ce2a81ebebd63d8026585ada7 Mon Sep 17 00:00:00 2001 From: John Ragland Date: Mon, 7 Apr 2025 15:51:57 -0700 Subject: [PATCH 1/3] added welch_nan --- src/xrsignal/spectral_analysis.py | 82 +++++++++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 5 deletions(-) diff --git a/src/xrsignal/spectral_analysis.py b/src/xrsignal/spectral_analysis.py index dfa515d..3d861c8 100755 --- a/src/xrsignal/spectral_analysis.py +++ b/src/xrsignal/spectral_analysis.py @@ -170,7 +170,7 @@ def __csd_chunk(data, dim, **kwargs): return np.abs(Pxy_x) -def welch(data, dim, dB=False, **kwargs): +def welch(data, dim, dB=False, nan=False, **kwargs): ''' Estimate power spectral density using welch method For now, an integer number of chunks in PSD dimension is required @@ -183,12 +183,14 @@ def welch(data, dim, dB=False, **kwargs): dimension to calculate PSD over dB : bool if True, return PSD in dB + nan : bool + if True, use welch_nan instead of welch. This throws out segments with NAN and still calculates the PSD ''' if isinstance(data, xr.DataArray): - Sxx = __welch_da(data, dim, dB=dB, **kwargs) + Sxx = __welch_da(data, dim, dB=dB, nan=nan, **kwargs) elif isinstance(data, xr.Dataset): - Sxx = data.map(__welch_da, dim=dim, dB=dB, **kwargs) + Sxx = data.map(__welch_da, dim=dim, dB=dB, nan=nan, **kwargs) else: raise Exception('data must be xr.DataArray or xr.Dataset') @@ -207,6 +209,8 @@ def __welch_chunk(da, dim, **kwargs): **kwargs passed to scipy.signal.welch ''' + # unpack nan kwarg + nan = kwargs.pop('nan', False) # Create new dimensions of PSD object original_dims = list(da.dims) @@ -216,15 +220,22 @@ def __welch_chunk(da, dim, **kwargs): new_dims[original_dims.index(dim)] = f'{dim}_frequency' new_dims.append(dim) + print('nan post map', nan) # Estimate PSD and convert to xarray.DataArray - f, P = signal.welch(da.values, axis=psd_dim_idx, **kwargs) + if nan: + # use welch_nan if nan is True + print('using nan handling') + f, P, _ = welch_nan(da.values, axis=psd_dim_idx, **kwargs) + else: + f, P = signal.welch(da.values, axis=psd_dim_idx, **kwargs) + P = np.expand_dims(P, -1) Px = xr.DataArray(P, dims=new_dims, coords={f'{dim}_frequency': f}) return Px -def __welch_da(da, dim, dB=False, **kwargs): +def __welch_da(da, dim, dB=False, nan=False, **kwargs): ''' Estimate power spectral density using welch method @@ -238,6 +249,8 @@ def __welch_da(da, dim, dB=False, **kwargs): dimension to calculate PSD over dB : bool if True, return PSD in dB + nan : bool + if True, use welch_nan instead of welch ''' ## Parse Kwargs @@ -329,9 +342,68 @@ def __welch_da(da, dim, dB=False, **kwargs): name=f'psd across {dim} dimension') kwargs['dim'] = dim + kwargs['nan'] = nan + + print(kwargs) Pxx = xr.map_blocks(__welch_chunk, da, template=template, kwargs=kwargs) if dB: return 10*np.log10(Pxx) else: return Pxx + + +def welch_nan(x, fs=1.0, window='hann', nperseg=256, noverlap=None, + nfft=None, detrend='constant', return_onesided=True, + scaling='density', axis=-1, average='mean'): + """ + Compute Welch's PSD estimate with NaN handling. + + This function divides data into segments, removes segments containing NaN values, + and then computes the PSD using only valid segments. + + Parameters are the same as scipy.signal.welch + + Returns + ------- + f : ndarray + Array of sample frequencies. + Pxx : ndarray + Power spectral density or power spectrum of x. + n_valid_segments : int + Number of valid segments (without NaN) used in the computation. + """ + # Handle default parameters similar to scipy.signal.welch + if noverlap is None: + noverlap = nperseg // 2 + + # Calculate number of segments and their starting indices + step = nperseg - noverlap + indices = np.arange(0, len(x) - nperseg + 1, step) + + # Create segments and check which ones contain NaN values + segments = np.array([x[i:i+nperseg] for i in indices]) + valid_segments = ~np.isnan(segments).any(axis=1) + + # Count the number of valid segments + n_valid_segments = np.sum(valid_segments) + valid_percent = n_valid_segments / len(segments) + + # If no valid segments, return NaN + if n_valid_segments == 0: + f = np.fft.rfftfreq(nperseg, d=1.0/fs) if return_onesided else np.fft.fftfreq(nperseg, d=1.0/fs) + return f, np.full(len(f), np.nan), 0 + + # Keep only valid segments + valid_data = segments[valid_segments] + + # Flatten into a 1D array with all valid segments concatenated + flattened_valid_data = valid_data.reshape(-1) + + # Call scipy.signal.welch with the valid data + # We set nperseg to the segment length and noverlap to 0 since we've already segmented the data + f, Pxx = signal.welch(flattened_valid_data, fs=fs, window=window, nperseg=nperseg, + noverlap=0, nfft=nfft, detrend=detrend, + return_onesided=return_onesided, scaling=scaling) + + return f, Pxx, valid_percent \ No newline at end of file From 60f9980ae5b0a912ee0009e2f571bd4e4de338bf Mon Sep 17 00:00:00 2001 From: John Ragland Date: Mon, 7 Apr 2025 16:01:24 -0700 Subject: [PATCH 2/3] removed print statements --- src/xrsignal/spectral_analysis.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/xrsignal/spectral_analysis.py b/src/xrsignal/spectral_analysis.py index 3d861c8..0eb455f 100755 --- a/src/xrsignal/spectral_analysis.py +++ b/src/xrsignal/spectral_analysis.py @@ -220,11 +220,8 @@ def __welch_chunk(da, dim, **kwargs): new_dims[original_dims.index(dim)] = f'{dim}_frequency' new_dims.append(dim) - print('nan post map', nan) # Estimate PSD and convert to xarray.DataArray if nan: - # use welch_nan if nan is True - print('using nan handling') f, P, _ = welch_nan(da.values, axis=psd_dim_idx, **kwargs) else: f, P = signal.welch(da.values, axis=psd_dim_idx, **kwargs) @@ -344,7 +341,6 @@ def __welch_da(da, dim, dB=False, nan=False, **kwargs): kwargs['dim'] = dim kwargs['nan'] = nan - print(kwargs) Pxx = xr.map_blocks(__welch_chunk, da, template=template, kwargs=kwargs) if dB: From 8774deb41b54d456f280a3f44c7c7e60f647aba8 Mon Sep 17 00:00:00 2001 From: John Ragland Date: Wed, 9 Apr 2025 07:58:57 -0700 Subject: [PATCH 3/3] updated welch_nan --- src/xrsignal/spectral_analysis.py | 192 ++++++++++++++++++++++++------ 1 file changed, 155 insertions(+), 37 deletions(-) diff --git a/src/xrsignal/spectral_analysis.py b/src/xrsignal/spectral_analysis.py index 0eb455f..6ac9a82 100755 --- a/src/xrsignal/spectral_analysis.py +++ b/src/xrsignal/spectral_analysis.py @@ -222,7 +222,7 @@ def __welch_chunk(da, dim, **kwargs): # Estimate PSD and convert to xarray.DataArray if nan: - f, P, _ = welch_nan(da.values, axis=psd_dim_idx, **kwargs) + f, P = welch_nan(da.values, axis=psd_dim_idx, **kwargs) else: f, P = signal.welch(da.values, axis=psd_dim_idx, **kwargs) @@ -348,17 +348,47 @@ def __welch_da(da, dim, dB=False, nan=False, **kwargs): else: return Pxx - -def welch_nan(x, fs=1.0, window='hann', nperseg=256, noverlap=None, - nfft=None, detrend='constant', return_onesided=True, - scaling='density', axis=-1, average='mean'): +def welch_nan(x, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None, + detrend='constant', return_onesided=True, scaling='density', + axis=-1, average='mean'): """ - Compute Welch's PSD estimate with NaN handling. + Estimate power spectral density using Welch's method, ignoring NaN values. - This function divides data into segments, removes segments containing NaN values, - and then computes the PSD using only valid segments. + This function behaves exactly like scipy.signal.welch but handles NaN values + by ignoring segments containing NaN when computing the average. - Parameters are the same as scipy.signal.welch + Parameters + ---------- + x : array_like + Time series of measurement values + fs : float, optional + Sampling frequency of the `x` time series. Defaults to 1.0. + window : str or tuple or array_like, optional + Desired window to use. Defaults to 'hann'. + nperseg : int, optional + Length of each segment. Defaults to 256. + noverlap : int, optional + Number of points to overlap between segments. Defaults to `nperseg // 2`. + nfft : int, optional + Length of the FFT used. If `None`, the default is `nperseg`. + detrend : str or function or `False`, optional + Specifies how to detrend each segment. If `detrend` is a string, it is + passed as the `type` argument to the `detrend` function. If it is a + function, it takes a segment and returns a detrended segment. If `False`, + no detrending is done. Defaults to 'constant'. + return_onesided : bool, optional + If `True`, return a one-sided spectrum for real data. If `False` return + a two-sided spectrum. If `x` is complex, the default is `False`. + scaling : { 'density', 'spectrum' }, optional + Selects between computing the power spectral density ('density') where + `Pxx` has units of V**2/Hz and computing the power spectrum ('spectrum') + where `Pxx` has units of V**2, if `x` is measured in V and fs is measured + in Hz. Defaults to 'density'. + axis : int, optional + Axis along which the periodogram is computed; the default is over the + last axis (i.e., axis=-1). + average : { 'mean', 'median' }, optional + Method to use when averaging periodograms. Defaults to 'mean'. Returns ------- @@ -366,40 +396,128 @@ def welch_nan(x, fs=1.0, window='hann', nperseg=256, noverlap=None, Array of sample frequencies. Pxx : ndarray Power spectral density or power spectrum of x. - n_valid_segments : int - Number of valid segments (without NaN) used in the computation. """ - # Handle default parameters similar to scipy.signal.welch + # Convert x to numpy array + x = np.asarray(x) + + # Check if there are any NaN values + if not np.any(np.isnan(x)): + # If no NaN values, use the original welch function + return signal.welch(x, fs=fs, window=window, nperseg=nperseg, + noverlap=noverlap, nfft=nfft, detrend=detrend, + return_onesided=return_onesided, scaling=scaling, + axis=axis, average=average) + + # Handle negative axis + if axis < 0: + axis = x.ndim + axis + + # Set default parameters + if nperseg is None: + nperseg = min(256, x.shape[axis]) + if noverlap is None: noverlap = nperseg // 2 - - # Calculate number of segments and their starting indices - step = nperseg - noverlap - indices = np.arange(0, len(x) - nperseg + 1, step) - # Create segments and check which ones contain NaN values - segments = np.array([x[i:i+nperseg] for i in indices]) - valid_segments = ~np.isnan(segments).any(axis=1) + if nfft is None: + nfft = nperseg + + # Get the window + if isinstance(window, str) or isinstance(window, tuple): + win = signal.get_window(window, nperseg) + else: + win = np.asarray(window) + if len(win) != nperseg: + raise ValueError('window must have length of nperseg') + + # Determine if input is complex + is_complex = np.iscomplexobj(x) - # Count the number of valid segments - n_valid_segments = np.sum(valid_segments) - valid_percent = n_valid_segments / len(segments) - - # If no valid segments, return NaN - if n_valid_segments == 0: - f = np.fft.rfftfreq(nperseg, d=1.0/fs) if return_onesided else np.fft.fftfreq(nperseg, d=1.0/fs) - return f, np.full(len(f), np.nan), 0 + # Calculate frequencies for return array + if return_onesided and not is_complex: + # Real input, one-sided frequency range + freqs = np.fft.rfftfreq(nfft, 1.0/fs) + else: + # Complex input or two-sided frequency range + if return_onesided and is_complex: + # For complex input with return_onesided=True, scipy.signal.welch + # would issue a warning and compute the full spectrum + import warnings + warnings.warn('return_onesided=True is ignored for complex input. ' + 'Computing two-sided spectrum.') + freqs = np.fft.fftfreq(nfft, 1.0/fs) - # Keep only valid segments - valid_data = segments[valid_segments] + # Init periodogram array + segment_psds = [] - # Flatten into a 1D array with all valid segments concatenated - flattened_valid_data = valid_data.reshape(-1) + # Calculate step size between segments + step = nperseg - noverlap - # Call scipy.signal.welch with the valid data - # We set nperseg to the segment length and noverlap to 0 since we've already segmented the data - f, Pxx = signal.welch(flattened_valid_data, fs=fs, window=window, nperseg=nperseg, - noverlap=0, nfft=nfft, detrend=detrend, - return_onesided=return_onesided, scaling=scaling) + # For each slice of the input data, calculate a periodogram + ind = 0 + while ind + nperseg <= x.shape[axis]: + # Extract segment + segment_slice = [slice(None)] * x.ndim + segment_slice[axis] = slice(ind, ind + nperseg) + segment = x[tuple(segment_slice)].copy() + + # Check if segment contains NaN + if not np.any(np.isnan(segment)): + # Detrend if needed + if detrend != False: + segment = signal.detrend(segment, type=detrend, axis=axis) + + # Apply window (broadcasting to the correct shape) + win_shape = [1] * segment.ndim + win_shape[axis] = len(win) + segment = segment * win.reshape(win_shape) + + # Compute FFT + if return_onesided and not is_complex: + # Real input, one-sided FFT + fft_data = np.fft.rfft(segment, n=nfft, axis=axis) + else: + # Complex input or two-sided FFT + fft_data = np.fft.fft(segment, n=nfft, axis=axis) + + # Compute PSD based on scaling + if scaling == 'density': + # Power spectral density + segment_psd = abs(fft_data)**2 / (fs * (win**2).sum()) + else: # scaling == 'spectrum' + # Power spectrum + segment_psd = abs(fft_data)**2 / win.sum()**2 + + # Apply one-sided scaling for real data + if return_onesided and not is_complex: + # Multiply by 2 for one-sided (except at DC and Nyquist) + if nfft % 2 == 0: # Even nfft, Nyquist present + segment_psd[..., 1:-1] *= 2 + else: # Odd nfft, Nyquist not present + segment_psd[..., 1:] *= 2 + + segment_psds.append(segment_psd) + + # Move to next segment + ind += step + + # If no valid segments found, return NaN array + if not segment_psds: + if return_onesided and not is_complex: + Pxx = np.full(freqs.shape, np.nan) + else: + Pxx = np.full(freqs.shape, np.nan) + return freqs, Pxx + + # Stack periodograms for averaging + segment_psds = np.stack(segment_psds, axis=0) + + # Average the periodograms + if average == 'mean': + Pxx = np.mean(segment_psds, axis=0) + elif average == 'median': + Pxx = np.median(segment_psds, axis=0) + else: + raise ValueError(f"Unknown average: {average}") - return f, Pxx, valid_percent \ No newline at end of file + return freqs, Pxx \ No newline at end of file