diff --git a/pyproject.toml b/pyproject.toml index bdfecd610..22c164250 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -294,25 +294,7 @@ dict = ["scitex-dict>=0.1.2"] # DSP Module - Digital Signal Processing # Use: pip install scitex[dsp] -dsp = [ - "scipy", - "sounddevice", - "matplotlib", - "joblib", - "tensorpac", - "ruamel.yaml", - "h5py", - "readchar", - "xarray", - "seaborn", - # # Heavy dependencies handled by _AVAILABLE flags - # "mne", - # "torch", - # "ripple_detection", - # "torchsummary", - # "julius", - # "torchaudio", -] +dsp = ["scitex-dsp>=0.1.0"] # DevTools Module - scitex.dev code analysis and development tools # Use: pip install scitex[devtools] diff --git a/src/scitex/dsp/README.md b/src/scitex/dsp/README.md deleted file mode 100755 index 4a7ad35d9..000000000 --- a/src/scitex/dsp/README.md +++ /dev/null @@ -1,147 +0,0 @@ - - - -# [`scitex.dsp`](https://github.com/ywatanabe1989/scitex/tree/main/src/scitex/dsp/) - -## Overview -The `scitex.dsp` module provides Digital Signal Processing (DSP) utilities written in **PyTorch**, optimized for **CUDA** devices when available. This module offers efficient implementations of various DSP algorithms and techniques. - -## Installation -```bash -pip install scitex -``` - -## Features -- PyTorch-based implementations for GPU acceleration -- Wavelet transforms and analysis -- Filtering operations (e.g., bandpass, lowpass, highpass) -- Spectral analysis tools -- Time-frequency analysis utilities -- Signal generation and manipulation functions -- Phase-Amplitude Coupling (PAC) analysis -- Modulation Index calculation -- Hilbert transform -- Power Spectral Density (PSD) estimation -- Resampling utilities - -## Galleries -
- - - - -
- -
- - - - -
- -## Quick Start -```python -# Parameters -SRC_FS = 1024 # Source sampling frequency -TGT_FS = 512 # Target sampling frequency -FREQS_HZ = [10, 30, 100] # Frequencies in Hz for periodic signals -LOW_HZ = 20 # Low frequency for bandpass filter -HIGH_HZ = 50 # High frequency for bandpass filter -SIGMA = 10 # Sigma for Gaussian filter -SIG_TYPES = [ - "uniform", - "gauss", - "periodic", - "chirp", - "ripple", - "meg", - "tensorpac", -] # Available signal types - - -# Demo Signal -xx, tt, fs = scitex.dsp.demo_sig( - t_sec=T_SEC, fs=SRC_FS, freqs_hz=FREQS_HZ, sig_type="chirp" -) -# xx.shape (batch_size, n_chs, seq_len) -# xx.shape (batch_size, n_chs, n_segments, seq_len) # when sig_type is "tensorpac" or "pac" - - -# # Various data types are automatically handled: -# xx = torch.tensor(xx).float() -# xx = torch.tensor(xx).float().cuda() -# xx = np.array(xx) -# xx = pd.DataFrame(xx) - -# Normalization -xx_norm = scitex.dsp.norm.z(xx) -xx_minmax = scitex.dsp.norm.minmax(xx) - -# Resampling -xx_resampled = scitex.dsp.resample(xx, fs, TGT_FS) - -# Noise addition -xx_gauss = scitex.dsp.add_noise.gauss(xx) -xx_white = scitex.dsp.add_noise.white(xx) -xx_pink = scitex.dsp.add_noise.pink(xx) -xx_brown = scitex.dsp.add_noise.brown(xx) - -# Filtering -xx_filted_bandpass = scitex.dsp.filt.bandpass(xx, fs, low_hz=LOW_HZ, high_hz=HIGH_HZ) -xx_filted_bandstop = scitex.dsp.filt.bandstop(xx, fs, low_hz=LOW_HZ, high_hz=HIGH_HZ) -xx_filted_gauss = scitex.dsp.filt.gauss(xx, sigma=SIGMA) - -# Hilbert Transformation -phase, amplitude = scitex.dsp.hilbert(xx) # or envelope - -# Wavelet Transformation -wavelet_coef, wavelet_freqs = scitex.dsp.wavelet(xx, fs) - -# Power Spetrum Density -psd, psd_freqs = scitex.dsp.psd(xx, fs) - -# Phase-Amplitude Coupling -pac, freqs_pha, freqs_amp = scitex.dsp.pac(xx, fs) # This process is computationally intensive. Please monitor RAM/VRAM usage. -``` - -## API Reference -- `scitex.dsp.wavelet_transform(signal, wavelet, level)`: Performs wavelet transform -- `scitex.dsp.bandpass_filter(signal, lowcut, highcut, fs)`: Applies bandpass filter -- `scitex.dsp.lowpass_filter(signal, cutoff, fs)`: Applies lowpass filter -- `scitex.dsp.highpass_filter(signal, cutoff, fs)`: Applies highpass filter -- `scitex.dsp.spectrogram(signal, fs, nperseg)`: Computes spectrogram -- `scitex.dsp.stft(signal, fs, nperseg)`: Performs Short-Time Fourier Transform -- `scitex.dsp.istft(stft, fs, nperseg)`: Performs Inverse Short-Time Fourier Transform -- `scitex.dsp.chirp(t, f0, f1, method)`: Generates chirp signal -- `scitex.dsp.hilbert(signal)`: Performs Hilbert transform -- `scitex.dsp.phase_amplitude_coupling(signal, fs)`: Calculates Phase-Amplitude Coupling -- `scitex.dsp.Modulation_index(signal, fs)`: Computes Modulation Index -- `scitex.dsp.psd(signal, fs)`: Estimates Power Spectral Density -- `scitex.dsp.resample(signal, orig_fs, new_fs)`: Resamples signal to new frequency -- `scitex.dsp.add_noise(signal, snr)`: Adds noise to signal with specified SNR - -## Use Cases -- Audio signal processing -- Biomedical signal analysis -- Vibration analysis -- Communication systems -- Radar and sonar signal processing -- Neuroscience data analysis - -## Performance -The `scitex.dsp` module leverages PyTorch's GPU acceleration capabilities, providing significant speedups for large-scale signal processing tasks when run on CUDA-enabled devices. - -## Contributing -Contributions to improve `scitex.dsp` are welcome. Please submit pull requests or open issues on the GitHub repository. - -## License -This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details. - -## Contact -Yusuke Watanabe (ywatanabe@scitex.ai) - -For more information and updates, please visit the [scitex GitHub repository](https://github.com/ywatanabe1989/scitex). diff --git a/src/scitex/dsp/__init__.py b/src/scitex/dsp/__init__.py index 2e6f66f6f..e8e37f736 100755 --- a/src/scitex/dsp/__init__.py +++ b/src/scitex/dsp/__init__.py @@ -1,86 +1,13 @@ -#!/usr/bin/env python3 -"""Scitex dsp module.""" +"""SciTeX dsp — thin compatibility shim for scitex-dsp.""" -import warnings +import sys as _sys -# Import example, params, norm, reference, filt, and add_noise modules as submodules -from . import add_noise, example, filt, norm, params, reference - -# Core imports that should always work -from ._crop import crop -from ._demo_sig import demo_sig -from ._detect_ripples import ( - _calc_relative_peak_position, - _drop_ripples_at_edges, - _find_events, - _preprocess, - _sort_columns, - detect_ripples, -) -from ._ensure_3d import ensure_3d -from ._hilbert import hilbert -from ._modulation_index import _reshape, modulation_index -from ._pac import pac -from ._psd import band_powers, psd -from ._resample import resample -from ._time import time -from ._transform import to_segments, to_sktime_df -from ._wavelet import wavelet - -# Try to import audio-related functions that require PortAudio try: - from ._listen import list_and_select_device - - _audio_available = True -except (ImportError, OSError): - warnings.warn( - "Audio functionality unavailable: PortAudio library not found. " - "Install PortAudio to use audio features (e.g., sudo apt-get install portaudio19-dev)", - ImportWarning, - ) - list_and_select_device = None - _audio_available = False - -# Try to import MNE-related functions -try: - from ._mne import get_eeg_pos - - _mne_available = True -except ImportError: - warnings.warn( - "MNE functionality unavailable. Install MNE-Python to use EEG position features.", - ImportWarning, - ) - get_eeg_pos = None - _mne_available = False - -__all__ = [ - "_calc_relative_peak_position", - "_drop_ripples_at_edges", - "_find_events", - "_preprocess", - "_reshape", - "_sort_columns", - "add_noise", - "band_powers", - "crop", - "demo_sig", - "detect_ripples", - "ensure_3d", - "example", - "filt", - "get_eeg_pos", - "hilbert", - "list_and_select_device", - "modulation_index", - "norm", - "pac", - "params", - "psd", - "reference", - "resample", - "time", - "to_segments", - "to_sktime_df", - "wavelet", -] + import scitex_dsp as _real +except ImportError as _e: + raise ImportError( + "scitex.dsp requires the 'scitex-dsp' package. " + "Install with: pip install scitex[dsp] (or: pip install scitex-dsp)" + ) from _e + +_sys.modules[__name__] = _real diff --git a/src/scitex/dsp/_crop.py b/src/scitex/dsp/_crop.py deleted file mode 100755 index a651387cf..000000000 --- a/src/scitex/dsp/_crop.py +++ /dev/null @@ -1,125 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "ywatanabe (2024-11-02 22:50:46)" -# File: ./scitex_repo/src/scitex/dsp/_crop.py - -import numpy as np - - -def crop(sig_2d, window_length, overlap_factor=0.0, axis=-1, time=None): - """ - Crops the input signal into overlapping windows of a specified length, - allowing for an arbitrary axis and considering a time vector. - - Parameters: - - sig_2d (numpy.ndarray): The input sig_2d array to be cropped. Can be multi-dimensional. - - window_length (int): The length of each window to crop the sig_2d into. - - overlap_factor (float): The fraction of the window that consecutive windows overlap. For example, an overlap_factor of 0.5 means 50% overlap. - - axis (int): The time axis along which to crop the sig_2d. - - time (numpy.ndarray): The time vector associated with the signal. Its length should match the signal's length along the cropping axis. - - Returns: - - cropped_windows (numpy.ndarray): The cropped signal windows. The shape depends on the input shape and the specified axis. - """ - # Ensure axis is in a valid range - if axis < 0: - axis += sig_2d.ndim - if axis >= sig_2d.ndim or axis < 0: - raise ValueError("Invalid axis. Axis out of range for sig_2d dimensions.") - - if time is not None: - # Validate the length of the time vector against the signal's dimension - if sig_2d.shape[axis] != len(time): - raise ValueError( - "Length of time vector does not match signal's dimension along the specified axis." - ) - - # Move the target axis to the last position - axes = np.arange(sig_2d.ndim) - axes[axis], axes[-1] = axes[-1], axes[axis] - sig_2d_permuted = np.transpose(sig_2d, axes) - - # Compute the number of windows and the step size - seq_len = sig_2d_permuted.shape[-1] - step = int(window_length * (1 - overlap_factor)) - n_windows = max( - 1, ((seq_len - window_length) // step + 1) - ) # Ensure at least 1 window - - # Crop the sig_2d into windows - cropped_windows = [] - cropped_times = [] - for i in range(n_windows): - start = i * step - end = start + window_length - cropped_windows.append(sig_2d_permuted[..., start:end]) - if time is not None: - cropped_times.append(time[start:end]) - - # Convert list of windows back to numpy array - cropped_windows = np.array(cropped_windows) - cropped_times = np.array(cropped_times) - - # Move the last axis back to its original position if necessary - if axis != sig_2d.ndim - 1: - # Compute the inverse permutation - inv_axes = np.argsort(axes) - cropped_windows = np.transpose(cropped_windows, axes=inv_axes) - - if time is None: - return cropped_windows - else: - return cropped_windows, cropped_times - - -def main(): - import random - - FS = 128 - N_CHS = 19 - RECORD_S = 13 - WINDOW_S = 2 - FACTOR = 0.5 - - # To pts - record_pts = int(RECORD_S * FS) - window_pts = int(WINDOW_S * FS) - - # Demo signal - sig2d = np.random.rand(N_CHS, record_pts) - time = np.arange(record_pts) / FS - - # Main - xx, tt = crop(sig2d, window_pts, overlap_factor=FACTOR, time=time) - - print(f"sig2d.shape: {sig2d.shape}") - print(f"xx.shape: {xx.shape}") - - # Validation - i_seg = random.randint(0, len(xx) - 1) - start = int(i_seg * window_pts * FACTOR) - end = start + window_pts - assert np.allclose(sig2d[:, start:end], xx[i_seg]) - - -if __name__ == "__main__": - # parser = argparse.ArgumentParser(description='') - # import argparse - # # Argument Parser - import sys - - import matplotlib.pyplot as plt - - import scitex - - # parser.add_argument('--var', '-v', type=int, default=1, help='') - # parser.add_argument('--flag', '-f', action='store_true', default=False, help='') - # args = parser.parse_args() - # Main - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start( - sys, plt, verbose=False - ) - main() - scitex.session.close(CONFIG, verbose=False, notify=False) - -# EOF diff --git a/src/scitex/dsp/_demo_sig.py b/src/scitex/dsp/_demo_sig.py deleted file mode 100755 index 348d6ddc3..000000000 --- a/src/scitex/dsp/_demo_sig.py +++ /dev/null @@ -1,380 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "2024-11-06 01:45:32 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/_demo_sig.py - -import random -import sys -import warnings - -import matplotlib.pyplot as plt -import numpy as np -from scipy.signal import chirp - -try: - import mne - from mne.datasets import sample - - MNE_AVAILABLE = True -except ImportError: - MNE_AVAILABLE = False - mne = None - sample = None - -try: - from ripple_detection.simulate import simulate_LFP, simulate_time - - RIPPLE_DETECTION_AVAILABLE = True -except ImportError: - RIPPLE_DETECTION_AVAILABLE = False - simulate_LFP = None - simulate_time = None - -try: - from tensorpac.signals import pac_signals_wavelet - - TENSORPAC_AVAILABLE = True -except ImportError: - TENSORPAC_AVAILABLE = False - pac_signals_wavelet = None - -from scitex.io import load_configs - -# Config -CONFIG = load_configs(verbose=False) - - -def _check_mne(): - if not MNE_AVAILABLE: - raise ImportError( - "MNE-Python is not installed. Please install with: pip install mne" - ) - - -def _check_ripple_detection(): - if not RIPPLE_DETECTION_AVAILABLE: - raise ImportError( - "ripple_detection is not installed. Please install with: pip install ripple_detection" - ) - - -def _check_tensorpac(): - if not TENSORPAC_AVAILABLE: - raise ImportError( - "tensorpac is not installed. Please install with: pip install tensorpac" - ) - - -# Functions -def demo_sig( - sig_type="periodic", - batch_size=8, - n_chs=19, - n_segments=20, - t_sec=4, - fs=512, - freqs_hz=None, - verbose=False, -): - """ - Generate demo signals for various signal types. - - Parameters - ---------- - sig_type : str, optional - Type of signal to generate. Options are "uniform", "gauss", "periodic", "chirp", "ripple", "meg", "tensorpac", "pac". - Default is "periodic". - batch_size : int, optional - Number of batches to generate. Default is 8. - n_chs : int, optional - Number of channels. Default is 19. - n_segments : int, optional - Number of segments for tensorpac and pac signals. Default is 20. - t_sec : float, optional - Duration of the signal in seconds. Default is 4. - fs : int, optional - Sampling frequency in Hz. Default is 512. - freqs_hz : list or None, optional - List of frequencies in Hz for periodic signals. If None, random frequencies will be used. - verbose : bool, optional - If True, print additional information. Default is False. - - Returns - ------- - tuple - A tuple containing: - - np.ndarray: Generated signal(s) with shape (batch_size, n_chs, time_samples) or (batch_size, n_chs, n_segments, time_samples) for tensorpac and pac signals. - - np.ndarray: Time array. - - int: Sampling frequency. - """ - assert sig_type in [ - "uniform", - "gauss", - "periodic", - "chirp", - "ripple", - "meg", - "tensorpac", - "pac", - ] - tt = np.linspace(0, t_sec, int(t_sec * fs), endpoint=False) - - if sig_type == "uniform": - return ( - np.random.uniform(low=-0.5, high=0.5, size=(batch_size, n_chs, len(tt))), - tt, - fs, - ) - - elif sig_type == "gauss": - return np.random.randn(batch_size, n_chs, len(tt)), tt, fs - - elif sig_type == "meg": - return ( - _demo_sig_meg( - batch_size=batch_size, - n_chs=n_chs, - t_sec=t_sec, - fs=fs, - verbose=verbose, - ).astype(np.float32)[..., : len(tt)], - tt, - fs, - ) - - elif sig_type == "tensorpac": - xx, tt = _demo_sig_tensorpac( - batch_size=batch_size, - n_chs=n_chs, - n_segments=n_segments, - t_sec=t_sec, - fs=fs, - ) - return xx.astype(np.float32)[..., : len(tt)], tt, fs - - elif sig_type == "pac": - xx = _demo_sig_pac( - batch_size=batch_size, - n_chs=n_chs, - n_segments=n_segments, - t_sec=t_sec, - fs=fs, - ) - return xx.astype(np.float32)[..., : len(tt)], tt, fs - - else: - fn_1d = { - "periodic": _demo_sig_periodic_1d, - "chirp": _demo_sig_chirp_1d, - "ripple": _demo_sig_ripple_1d, - }.get(sig_type) - - return ( - ( - np.array( - [ - fn_1d( - t_sec=t_sec, - fs=fs, - freqs_hz=freqs_hz, - verbose=verbose, - ) - for _ in range(int(batch_size * n_chs)) - ] - ) - .reshape(batch_size, n_chs, -1) - .astype(np.float32)[..., : len(tt)] - ), - tt, - fs, - ) - - -def _demo_sig_pac( - batch_size=8, - n_chs=19, - t_sec=4, - fs=512, - f_pha=10, - f_amp=100, - noise=0.8, - n_segments=20, - verbose=False, -): - """ - Generate a demo signal with phase-amplitude coupling. - - Parameters - ---------- - batch_size (int): Number of batches. - n_chs (int): Number of channels. - t_sec (int): Duration of the signal in seconds. - fs (int): Sampling frequency. - f_pha (float): Frequency of the phase-modulating signal. - f_amp (float): Frequency of the amplitude-modulated signal. - noise (float): Noise level added to the signal. - n_segments (int): Number of segments. - verbose (bool): If True, print additional information. - - Returns - ------- - np.array: Generated signals with shape (batch_size, n_chs, n_segments, seq_len). - """ - seq_len = t_sec * fs - t = np.arange(seq_len) / fs - if verbose: - print(f"Generating signal with length: {seq_len}") - - # Create empty array to store the signals - signals = np.zeros((batch_size, n_chs, n_segments, seq_len)) - - for b in range(batch_size): - for ch in range(n_chs): - for seg in range(n_segments): - # Phase signal - theta = np.sin(2 * np.pi * f_pha * t) - # Amplitude envelope - amplitude_env = 1 + np.sin(2 * np.pi * f_amp * t) - # Combine phase and amplitude modulation - signal = theta * amplitude_env - # Add Gaussian noise - signal += noise * np.random.randn(seq_len) - signals[b, ch, seg, :] = signal - - return signals - - -def _demo_sig_tensorpac( - batch_size=8, - n_chs=19, - t_sec=4, - fs=512, - f_pha=10, - f_amp=100, - noise=0.8, - n_segments=20, - verbose=False, -): - _check_tensorpac() - n_times = int(t_sec * fs) - x_2d, tt = pac_signals_wavelet( - sf=fs, - f_pha=f_pha, - f_amp=f_amp, - noise=noise, - n_epochs=n_segments, - n_times=n_times, - ) - x_3d = np.stack([x_2d for _ in range(batch_size)], axis=0) - x_4d = np.stack([x_3d for _ in range(n_chs)], axis=1) - return x_4d, tt - - -def _demo_sig_meg(batch_size=8, n_chs=19, t_sec=10, fs=512, verbose=False, **kwargs): - _check_mne() - data_path = sample.data_path() - meg_path = data_path / "MEG" / "sample" - raw_fname = meg_path / "sample_audvis_raw.fif" - fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif" - - # Load real data as the template - raw = mne.io.read_raw_fif(raw_fname, verbose=verbose) - raw = raw.crop(tmax=t_sec, verbose=verbose) - raw = raw.resample(fs, verbose=verbose) - raw.set_eeg_reference(projection=True, verbose=verbose) - - return raw.get_data( - picks=raw.ch_names[: batch_size * n_chs], verbose=verbose - ).reshape(batch_size, n_chs, -1) - - -def _demo_sig_periodic_1d(t_sec=10, fs=512, freqs_hz=None, verbose=False, **kwargs): - """Returns a demo signal with the shape (t_sec*fs,).""" - if freqs_hz is None: - n_freqs = random.randint(1, 5) - freqs_hz = np.random.permutation(np.arange(fs))[:n_freqs] - if verbose: - print(f"freqs_hz was randomly determined as {freqs_hz}") - - n = int(t_sec * fs) - t = np.linspace(0, t_sec, n, endpoint=False) - - summed = np.array( - [ - np.random.rand() * np.sin((f_hz * t + np.random.rand()) * (2 * np.pi)) - for f_hz in freqs_hz - ] - ).sum(axis=0) - return summed - - -def _demo_sig_chirp_1d( - t_sec=10, fs=512, low_hz=None, high_hz=None, verbose=False, **kwargs -): - if low_hz is None: - low_hz = random.randint(1, 20) - if verbose: - warnings.warn(f"low_hz was randomly determined as {low_hz}.") - - if high_hz is None: - high_hz = random.randint(100, 1000) - if verbose: - warnings.warn(f"high_hz was randomly determined as {high_hz}.") - - n = int(t_sec * fs) - t = np.linspace(0, t_sec, n, endpoint=False) - x = chirp(t, low_hz, t[-1], high_hz) - x *= 1.0 + 0.5 * np.sin(2.0 * np.pi * 3.0 * t) - return x - - -def _demo_sig_ripple_1d(t_sec=10, fs=512, **kwargs): - _check_ripple_detection() - n_samples = t_sec * fs - t = simulate_time(n_samples, fs) - n_ripples = random.randint(1, 5) - mid_time = np.random.permutation(t)[:n_ripples] - return simulate_LFP(t, mid_time, noise_amplitude=1.2, ripple_amplitude=5) - - -if __name__ == "__main__": - import scitex - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - import scitex - - SIG_TYPES = [ - "uniform", - "gauss", - "periodic", - "chirp", - "meg", - "ripple", - "tensorpac", - "pac", - ] - - i_batch, i_ch, i_segment = 0, 0, 0 - fig, axes = scitex.plt.subplots(nrows=len(SIG_TYPES)) - for ax, (i_sig_type, sig_type) in zip(axes, enumerate(SIG_TYPES)): - xx, tt, fs = demo_sig(sig_type=sig_type) - if sig_type not in ["tensorpac", "pac"]: - ax.plot(tt, xx[i_batch, i_ch], label=sig_type) - else: - ax.plot(tt, xx[i_batch, i_ch, i_segment], label=sig_type) - ax.legend(loc="upper left") - fig.suptitle("Demo signals") - fig.supxlabel("Time [s]") - fig.supylabel("Amplitude [?V]") - scitex.io.save(fig, "traces.png") - - # Close - scitex.session.close(CONFIG) - -# EOF - -""" -/home/ywatanabe/proj/entrance/scitex/dsp/_demo_sig.py -""" - -# EOF diff --git a/src/scitex/dsp/_detect_ripples.py b/src/scitex/dsp/_detect_ripples.py deleted file mode 100755 index 0e1a3bed5..000000000 --- a/src/scitex/dsp/_detect_ripples.py +++ /dev/null @@ -1,216 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-11-05 00:24:54 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/_detect_ripples.py - -import numpy as np -import pandas as pd -from scipy.signal import find_peaks - -from scitex.gen._norm import to_z - -from ._demo_sig import demo_sig -from ._hilbert import hilbert -from ._resample import resample -from .filt import bandpass, gauss - - -def detect_ripples( - xx, - fs, - low_hz, - high_hz, - sd=2.0, - smoothing_sigma_ms=4, - min_duration_ms=10, - return_preprocessed_signal=False, -): - """ - xx: 2-dimensional (n_chs, seq_len) or 3-dimensional (batch_size, n_chs, seq_len) wide-band signal. - """ - - try: - xx_r, fs_r = _preprocess(xx, fs, low_hz, high_hz, smoothing_sigma_ms) - df = _find_events(xx_r, fs_r, sd, min_duration_ms) - df = _drop_ripples_at_edges(df, low_hz, xx_r, fs_r) - df = _calc_relative_peak_position(df) - # df = _calc_incidence(df, xx_r, fs_r) - df = _sort_columns(df) - - if not return_preprocessed_signal: - return df - - elif return_preprocessed_signal: - return df, xx_r, fs_r - - except ValueError as e: - print("Caught an error:", e) - - -def _preprocess(xx, fs, low_hz, high_hz, smoothing_sigma_ms=4): - # Ensures three dimensional - if xx.ndim == 2: - xx = xx[np.newaxis] - assert xx.ndim == 3 - - # For readability - RIPPLE_BANDS = np.vstack([[low_hz, high_hz]]) - - # Downsampling - fs_tgt = low_hz * 3 - xx = resample(xx, float(fs), float(fs_tgt)) - fs = fs_tgt - - # Subtracts the global mean to reduce false detection due to EMG signal - xx -= np.nanmean(xx, axis=1, keepdims=True) - - # Bandpass Filtering - xx = ( - ( - bandpass( - np.array(xx), - fs_tgt, - RIPPLE_BANDS, - ) - ) - .squeeze(-2) - .astype(np.float64) - ) - - # Calculate RMS - xx = xx**2 - _, xx = hilbert(xx) - xx = gauss(xx, smoothing_sigma_ms * 1e-3 * fs_tgt).squeeze(-2) - xx = np.sqrt(xx) - - # Scales across channels - xx = xx.mean(axis=1) - xx = to_z(xx, dim=-1) - - return xx, fs_tgt - - -def _find_events(xx_r, fs_r, sd, min_duration_ms): - def _find_events_1d(xx_ri, fs_r, sd, min_duration_ms): - # Finds peaks over the designated standard deviation - peaks, properties = find_peaks(xx_ri, height=sd) - - # Determines the range around each peak (customize as needed) - peaks_all = [] - peak_ranges = [] - peak_amplitudes_sd = [] - - for peak in peaks: - left_bound = np.where(xx_ri[:peak] < 0)[0] - right_bound = np.where(xx_ri[peak:] < 0)[0] - - left_ips = left_bound.max() if left_bound.size > 0 else peak - right_ips = peak + right_bound.min() if right_bound.size > 0 else peak - - # Avoid duplicates: Check if the current peak range is already listed - if not any( - (left_ips == start and right_ips == end) for start, end in peak_ranges - ): - peaks_all.append(peak) - peak_ranges.append((left_ips, right_ips)) - peak_amplitudes_sd.append(xx_ri[peak]) - - # Converts to DataFrame - if peak_ranges: - starts, ends = zip(*peak_ranges) if peak_ranges else ([], []) - df = pd.DataFrame( - { - "start_s": np.hstack(starts) / fs_r, - "peak_s": np.hstack(peaks_all) / fs_r, - "end_s": np.hstack(ends) / fs_r, - "peak_amp_sd": np.hstack(peak_amplitudes_sd), - } - ).round(3) - else: - df = pd.DataFrame(columns=["start_s", "peak_s", "end_s", "peak_amp_sd"]) - - # Duration - df["duration_s"] = df.end_s - df.start_s - - # Filters events with short duration - df = df[df.duration_s > (min_duration_ms * 1e-3)] - - return df - - if xx_r.ndim == 1: - xx_r = xx_r[np.newaxis, :] - assert xx_r.ndim == 2 - - dfs = [] - for i_ch in range(len(xx_r)): - xx_ri = xx_r[i_ch] - df_i = _find_events_1d(xx_ri, fs_r, sd, min_duration_ms) - df_i.index = [i_ch for _ in range(len(df_i))] - dfs.append(df_i) - dfs = pd.concat(dfs) - - return dfs - - -def _drop_ripples_at_edges(df, low_hz, xx_r, fs_r): - edge_s = 1 / low_hz * 3 - indi_drop = (df.start_s < edge_s) + (xx_r.shape[-1] / fs_r - edge_s < df.end_s) - df = df[~indi_drop] - return df - - -def _calc_relative_peak_position(df): - delta_s = df.peak_s - df.start_s - rel_peak = delta_s / df.duration_s - df["rel_peak_pos"] = np.round(rel_peak, 3) - return df - - -# def _calc_incidence(df, xx_r, fs_r): -# n_ripples = len(df) -# rec_s = xx_r.shape[-1] / fs_r -# df["incidence_hz"] = n_ripples / rec_s -# return df - - -def _sort_columns(df): - sorted_columns = [ - "start_s", - "end_s", - "duration_s", - "peak_s", - "rel_peak_pos", - "peak_amp_sd", - # "incidence_hz", - ] - df = df[sorted_columns] - return df - - -def main(): - xx, tt, fs = demo_sig(sig_type="ripple") - df = detect_ripples(xx, fs, 80, 140) - print(df) - - -if __name__ == "__main__": - import sys - - import matplotlib.pyplot as plt - - import scitex - - # # Argument Parser - # import argparse - # parser = argparse.ArgumentParser(description='') - # parser.add_argument('--var', '-v', type=int, default=1, help='') - # parser.add_argument('--flag', '-f', action='store_true', default=False, help='') - # args = parser.parse_args() - # Main - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start( - sys, plt, verbose=False - ) - main() - scitex.session.close(CONFIG, verbose=False, notify=False) - -# EOF diff --git a/src/scitex/dsp/_ensure_3d.py b/src/scitex/dsp/_ensure_3d.py deleted file mode 100755 index 3fba15efd..000000000 --- a/src/scitex/dsp/_ensure_3d.py +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-11-05 01:03:47 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/_ensure_3d.py - -from scitex.decorators import signal_fn - - -@signal_fn -def ensure_3d(x): - if x.ndim == 1: # assumes (seq_len,) - x = x.unsqueeze(0).unsqueeze(0) - elif x.ndim == 2: # assumes (batch_siize, seq_len) - x = x.unsqueeze(1) - return x - - -# EOF diff --git a/src/scitex/dsp/_hilbert.py b/src/scitex/dsp/_hilbert.py deleted file mode 100755 index 181083411..000000000 --- a/src/scitex/dsp/_hilbert.py +++ /dev/null @@ -1,78 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-11-04 02:07:11 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/_hilbert.py - -""" -This script does XYZ. -""" - -import sys - -import matplotlib.pyplot as plt - -from scitex.decorators import signal_fn -from scitex.nn._Hilbert import Hilbert - - -# Functions -@signal_fn -def hilbert( - x, - dim=-1, -): - y = Hilbert(x.shape[-1], dim=dim)(x) - return y[..., 0], y[..., 1] - - -if __name__ == "__main__": - import scitex - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - - # Parameters - T_SEC = 1.0 - FS = 400 - SIG_TYPE = "chirp" - - # Demo signal - xx, tt, fs = scitex.dsp.demo_sig(t_sec=T_SEC, fs=FS, sig_type=SIG_TYPE) - - # Main - pha, amp = hilbert( - xx, - dim=-1, - ) - # (32, 19, 1280, 2) - - # Plots - fig, axes = scitex.plt.subplots(nrows=2, sharex=True) - fig.suptitle("Hilbert Transformation") - - axes[0].plot(tt, xx[0, 0], label=SIG_TYPE) - axes[0].plot(tt, amp[0, 0], label="Amplidue") - axes[0].legend() - # axes[0].set_xlabel("Time [s]") - axes[0].set_ylabel("Amplitude [?V]") - - axes[1].plot(tt, pha[0, 0], label="Phase") - axes[1].legend() - - axes[1].set_xlabel("Time [s]") - axes[1].set_ylabel("Phase [rad]") - - # plt.show() - scitex.io.save(fig, "traces.png") - - # Close - scitex.session.close(CONFIG) - -# EOF - -""" -/home/ywatanabe/proj/entrance/scitex/dsp/_hilbert.py -""" - - -# EOF diff --git a/src/scitex/dsp/_listen.py b/src/scitex/dsp/_listen.py deleted file mode 100755 index b7bf9bd50..000000000 --- a/src/scitex/dsp/_listen.py +++ /dev/null @@ -1,702 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-11-08 09:15:59 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/_listen.py - -import os -import sys - -import matplotlib.pyplot as plt -import sounddevice as sd - -os.environ["PULSE_SERVER"] = "unix:/mnt/wslg/PulseServer" - -# # WSL2 Sound Support -# export PULSE_SERVER=unix:/mnt/wslg/PulseServer - - -def list_and_select_device() -> int: - """ - List available audio devices and prompt user to select one. - - Example - ------- - >>> device_id = list_and_select_device() - Available audio devices: - ... - Enter the ID of the device you want to use: - - Returns - ------- - int - Selected device ID - """ - try: - print("Available audio devices:") - devices = sd.query_devices() - print(devices) - device_id = int(input("Enter the ID of the device you want to use: ")) - if device_id not in range(len(devices)): - raise ValueError(f"Invalid device ID: {device_id}") - return device_id - except (ValueError, sd.PortAudioError) as err: - print(f"Error during device selection: {err}") - return 0 - - -if __name__ == "__main__": - import scitex - - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - - signal, time_points, sampling_freq = scitex.dsp.demo_sig("chirp") - - device_id = list_and_select_device() - sd.default.device = device_id - - listen(signal, sampling_freq) - - scitex.session.close(CONFIG) - -# def play_audio( -# samples: np.ndarray, fs: int = 44100, channels: int = 1 -# ) -> None: -# """Play audio using PyAudio""" -# print("Initializing PyAudio...") -# p = pyaudio.PyAudio() - -# # List available devices -# print("\nAvailable audio devices:") -# for i in range(p.get_device_count()): -# dev = p.get_device_info_by_index(i) -# print(f"Device {i}: {dev['name']}") - -# try: -# # Rest of the code remains the same -# if samples.dtype != np.float32: -# print(f"Converting from {samples.dtype} to float32...") -# samples = samples.astype(np.float32) - -# if len(samples) == 0: -# print("No input samples, creating test tone...") -# duration = 1 -# t = np.linspace(0, duration, int(fs * duration)) -# samples = np.sin(2 * np.pi * 440 * t) - -# print(f"Opening audio stream (fs={fs}Hz, channels={channels})...") -# stream = p.open( -# format=pyaudio.paFloat32, channels=channels, rate=fs, output=True -# ) - -# print("Playing audio...") -# stream.write(samples.tobytes()) - -# except Exception as e: -# print(f"Error: {e}") -# finally: -# print("Cleaning up...") -# if "stream" in locals(): -# stream.stop_stream() -# stream.close() -# p.terminate() -# print("Done.") - -# if __name__ == "__main__": -# # Test with a simple sine wave -# duration = 2 -# fs = 44100 -# t = np.linspace(0, duration, int(fs * duration)) -# test_signal = np.sin(2 * np.pi * 440 * t) # 440 Hz -# play_audio(test_signal, fs) - -# # def record_audio( -# # duration: float = 5.0, fs: int = 44100, channels: int = 1 -# # ) -> np.ndarray: -# # """Record audio using PyAudio with WSL2 compatibility""" -# # chunk = 1024 -# # format = pyaudio.paFloat32 - -# # p = pyaudio.PyAudio() - -# # # List available devices -# # for i in range(p.get_device_count()): -# # print(p.get_device_info_by_index(i)) - -# # try: -# # stream = p.open( -# # format=format, -# # channels=channels, -# # rate=fs, -# # input=True, -# # frames_per_buffer=chunk, -# # input_device_index=None, # Use default device -# # ) - -# # frames = [] -# # print("Recording...") - -# # for _ in range(0, int(fs / chunk * duration)): -# # data = stream.read(chunk) -# # frames.append(data) - -# # print("Done recording") -# # return np.frombuffer(b"".join(frames), dtype=np.float32) - -# # finally: -# # if "stream" in locals(): -# # stream.stop_stream() -# # stream.close() -# # p.terminate() - -# # def record_audio( -# # duration: float = 5.0, fs: int = 44100, channels: int = 1 -# # ) -> np.ndarray: -# # """ -# # Record audio using PyAudio. -# # """ -# # chunk = 1024 -# # format = pyaudio.paFloat32 - -# # p = pyaudio.PyAudio() - -# # stream = p.open( -# # format=format, -# # channels=channels, -# # rate=fs, -# # input=True, -# # frames_per_buffer=chunk, -# # ) - -# # frames = [] - -# # for _ in range(0, int(fs / chunk * duration)): -# # data = stream.read(chunk) -# # frames.append(data) - -# # stream.stop_stream() -# # stream.close() -# # p.terminate() - -# # return np.frombuffer(b"".join(frames), dtype=np.float32) - -# # if __name__ == "__main__": -# # # Test recording -# # audio = record_audio(duration=3.0) -# # print(f"Recorded {len(audio)} samples") - -# # # #!/usr/bin/env python3 -# # # # -*- coding: utf-8 -*- -# # # # Time-stamp: "2024-11-08 08:56:12 (ywatanabe)" -# # # # File: ./scitex_repo/src/scitex/dsp/_listen.py - -# # # import os -# # # import sys -# # # from typing import Any, Dict, Literal, Tuple - -# # # import matplotlib.pyplot as plt -# # # import scitex -# # # import numpy as np -# # # from scipy.signal import resample - -# # # os.environ["DISPLAY"] = ":0" # Set a default display - -# # # # Avoid GUI/clipboard dependencies -# # # def dummy_clipboard_get(): -# # # raise NotImplementedError( -# # # "Clipboard not available in headless environment" -# # # ) - -# # # try: -# # # from IPython import get_ipython - -# # # ipython = get_ipython() -# # # if ipython is not None: -# # # ipython.hooks.clipboard_get = dummy_clipboard_get -# # # except ImportError: -# # # pass - -# # # """ -# # # Functionality: -# # # - Provides audio playback and signal sonification functionality -# # # - Supports multiple sonification methods (frequency shift, AM/FM Modulation) -# # # - Includes device selection and audio information display utilities -# # # Input: -# # # - Multichannel signal arrays (numpy.ndarray) -# # # - Sampling frequency and sonification parameters -# # # Output: -# # # - Audio playback through specified output device -# # # Prerequisites: -# # # - PortAudio library (install with: sudo apt-get install portaudio19-dev) -# # # - sounddevice package -# # # """ - -# # # """Imports""" -# # # """Config""" -# # # CONFIG = scitex.gen.load_configs() - -# # # """Functions""" - -# # # def frequency_shift( -# # # signal: np.ndarray, -# # # shift_factor: int = 200, -# # # ) -> np.ndarray: -# # # """ -# # # Shift signal frequencies by resampling. - -# # # Example -# # # ------- -# # # >>> signal = np.sin(2 * np.pi * 10 * np.linspace(0, 1, 1000)) -# # # >>> shifted = frequency_shift(signal, shift_factor=200) - -# # # Parameters -# # # ---------- -# # # signal : np.ndarray -# # # Input signal to be frequency shifted -# # # shift_factor : int -# # # Frequency multiplication factor - -# # # Returns -# # # ------- -# # # np.ndarray -# # # Frequency shifted signal -# # # """ -# # # num_samples = int(len(signal) * shift_factor) -# # # return resample(signal, num_samples) - -# # # def am_Modulation( -# # # signal: np.ndarray, carrier_freq: float = 440, fs: int = 44_100 -# # # ) -> np.ndarray: -# # # """ -# # # Perform amplitude Modulation on input signal. - -# # # Example -# # # ------- -# # # >>> signal = np.sin(2 * np.pi * 10 * np.linspace(0, 1, 1000)) -# # # >>> modulated = am_Modulation(signal, carrier_freq=440, fs=44100) - -# # # Parameters -# # # ---------- -# # # signal : np.ndarray -# # # Input signal to modulate -# # # carrier_freq : float -# # # Carrier frequency in Hz -# # # fs : int -# # # Sampling frequency in Hz - -# # # Returns -# # # ------- -# # # np.ndarray -# # # Amplitude modulated signal -# # # """ -# # # t = np.arange(len(signal)) / fs -# # # carrier = np.sin(2 * np.pi * carrier_freq * t) -# # # return (1 + signal) * carrier - -# # # def fm_Modulation( -# # # signal: np.ndarray, -# # # carrier_freq: float = 440, -# # # sensitivity: float = 0.5, -# # # fs: int = 44_100, -# # # ) -> np.ndarray: -# # # """ -# # # Perform frequency Modulation on input signal. - -# # # Example -# # # ------- -# # # >>> signal = np.sin(2 * np.pi * 10 * np.linspace(0, 1, 1000)) -# # # >>> modulated = fm_Modulation(signal, carrier_freq=440, sensitivity=0.5, fs=44100) - -# # # Parameters -# # # ---------- -# # # signal : np.ndarray -# # # Input signal to modulate -# # # carrier_freq : float -# # # Carrier frequency in Hz -# # # sensitivity : float -# # # Frequency sensitivity factor -# # # fs : int -# # # Sampling frequency in Hz - -# # # Returns -# # # ------- -# # # np.ndarray -# # # Frequency modulated signal -# # # """ -# # # t = np.arange(len(signal)) / fs -# # # phase = 2 * np.pi * carrier_freq * t + sensitivity * np.cumsum(signal) / fs -# # # return np.sin(phase) - -# # # def sonify_eeg( -# # # signal_array: np.ndarray, -# # # sampling_freq: int, -# # # method: Literal["shift", "am", "fm"] = "shift", -# # # channels: Tuple[int, ...] = (0, 1), -# # # target_fs: int = 44_100, -# # # **kwargs: Dict[str, Any], -# # # ) -> None: -# # # """ -# # # Convert EEG signal to audio using various sonification methods. - -# # # Example -# # # ------- -# # # >>> eeg_data = np.random.randn(1, 32, 1000) # Mock EEG data -# # # >>> sonify_eeg(eeg_data, 250, method='fm', channels=(0,1)) - -# # # Parameters -# # # ---------- -# # # signal_array : np.ndarray -# # # EEG signal array of shape (batch_size, n_channels, sequence_length) -# # # sampling_freq : int -# # # Original sampling frequency of the EEG signal -# # # method : {'shift', 'am', 'fm'} -# # # Sonification method to use -# # # channels : Tuple[int, ...] -# # # Channels to include in sonification -# # # target_fs : int -# # # Target audio sampling frequency -# # # **kwargs : Dict[str, Any] -# # # Additional parameters for specific sonification methods - -# # # Returns -# # # ------- -# # # None -# # # """ - -# # # if not isinstance(signal_array, np.ndarray): -# # # signal_array = np.array(signal_array) - -# # # if len(signal_array.shape) != 3: -# # # raise ValueError(f"Expected 3D array, got shape {signal_array.shape}") - -# # # if max(channels) >= signal_array.shape[1]: -# # # raise ValueError( -# # # f"Channel index {max(channels)} out of range (max: {signal_array.shape[1]-1})" -# # # ) - -# # # selected_channels = signal_array[:, channels, :].mean(axis=1) -# # # signal = selected_channels.mean(axis=0) - -# # # # Normalize -# # # signal = signal / np.max(np.abs(signal)) - -# # # # Apply selected method -# # # if method == "shift": -# # # audio = frequency_shift(signal, kwargs.get("shift_factor", 200)) -# # # elif method == "am": -# # # audio = am_Modulation( -# # # signal, kwargs.get("carrier_freq", 440), target_fs -# # # ) -# # # elif method == "fm": -# # # audio = fm_Modulation( -# # # signal, -# # # kwargs.get("carrier_freq", 440), -# # # kwargs.get("sensitivity", 0.5), -# # # target_fs, -# # # ) -# # # else: -# # # raise ValueError(f"Unknown method: {method}") - -# # # if len(audio) < 100: -# # # raise ValueError("Audio signal too short after processing") - -# # # sd.play(audio, target_fs) -# # # sd.wait() - -# # # def listen( -# # # signal_array: np.ndarray, -# # # sampling_freq: int, -# # # channels: Tuple[int, ...] = (0, 1), -# # # target_fs: int = 44_100, -# # # ) -> None: -# # # """ -# # # Play selected channels of a multichannel signal array as audio. - -# # # Example -# # # ------- -# # # >>> signal = np.random.randn(1, 2, 1000) # Random stereo signal -# # # >>> listen(signal, 16000, channels=(0, 1)) - -# # # Parameters -# # # ---------- -# # # signal_array : np.ndarray -# # # Signal array of shape (batch_size, n_channels, sequence_length) -# # # sampling_freq : int -# # # Original sampling frequency of the signal -# # # channels : Tuple[int, ...] -# # # Tuple of channel indices to listen to -# # # target_fs : int -# # # Target sampling frequency for playback - -# # # Returns -# # # ------- -# # # None -# # # """ -# # # if not isinstance(signal_array, np.ndarray): -# # # signal_array = np.array(signal_array) - -# # # if len(signal_array.shape) != 3: -# # # raise ValueError(f"Expected 3D array, got shape {signal_array.shape}") - -# # # if max(channels) >= signal_array.shape[1]: -# # # raise ValueError( -# # # f"Channel index {max(channels)} out of range (max: {signal_array.shape[1]-1})" -# # # ) - -# # # selected_channels = signal_array[:, channels, :].mean(axis=1) -# # # audio_signal = selected_channels.mean(axis=0) - -# # # if sampling_freq != target_fs: -# # # num_samples = int(round(len(audio_signal) * target_fs / sampling_freq)) -# # # audio_signal = resample(audio_signal, num_samples) - -# # # sd.play(audio_signal, target_fs) -# # # sd.wait() - -# # # def print_device_info() -> None: -# # # """ -# # # Display information about the default audio output device. - -# # # Example -# # # ------- -# # # >>> print_device_info() -# # # Default Output Device Info: -# # # -# # # """ -# # # try: -# # # device_info = sd.query_devices(kind="output") -# # # print(f"Default Output Device Info: \n{device_info}") -# # # except sd.PortAudioError as err: -# # # print(f"Error querying audio devices: {err}") - -# # # if __name__ == "__main__": -# # # import scitex - -# # # CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - -# # # # Generate a test signal if demo_sig fails -# # # try: -# # # signal, time_points, sampling_freq = scitex.dsp.demo_sig("chirp") -# # # except Exception as err: -# # # print(f"Failed to load demo signal: {err}") -# # # # Generate a simple test signal -# # # duration = 2 # seconds -# # # sampling_freq = 1000 # Hz -# # # t = np.linspace(0, duration, int(duration * sampling_freq)) -# # # test_signal = np.sin(2 * np.pi * 10 * t) # 10 Hz sine wave -# # # signal = test_signal.reshape(1, 1, -1) - -# # # # Try to get audio device -# # # try: -# # # device_id = list_and_select_device() -# # # sd.default.device = device_id -# # # except Exception as err: -# # # print(f"Failed to set audio device: {err}") -# # # print("Using default audio device") -# # # device_id = None - -# # # # Test different sonification methods with error handling -# # # methods = ["shift", "am", "fm"] -# # # for method in methods: -# # # try: -# # # print(f"\nTesting {method} sonification...") -# # # sonify_eeg(signal, sampling_freq, method=method) -# # # except Exception as err: -# # # print(f"Failed to play {method} sonification: {err}") - -# # # scitex.session.close(CONFIG) - -# # # # EOF - -# # # # #!/usr/bin/env python3 -# # # # # -*- coding: utf-8 -*- -# # # # # Time-stamp: "2024-11-07 18:58:37 (ywatanabe)" -# # # # # File: ./scitex_repo/src/scitex/dsp/_listen.py - -# # # # import sys -# # # # from typing import Tuple - -# # # # import matplotlib.pyplot as plt -# # # # import scitex -# # # # import numpy as np -# # # # import sounddevice as sd -# # # # from scipy.signal import resample - -# # # # """ -# # # # Functionality: -# # # # - Provides audio playback functionality for multichannel signal arrays -# # # # - Includes device selection and audio information display utilities -# # # # Input: -# # # # - Multichannel signal arrays (numpy.ndarray) -# # # # - Sampling frequency and channel selection -# # # # Output: -# # # # - Audio playback through specified output device -# # # # Prerequisites: -# # # # - PortAudio library (install with: sudo apt-get install portaudio19-dev) -# # # # - sounddevice package -# # # # """ - -# # # # """Imports""" -# # # # """Config""" -# # # # CONFIG = scitex.gen.load_configs() - -# # # # """Functions""" -# # # # def listen( -# # # # signal_array: np.ndarray, -# # # # sampling_freq: int, -# # # # channels: Tuple[int, ...] = (0, 1), -# # # # target_fs: int = 44_100, -# # # # ) -> None: -# # # # """ -# # # # Play selected channels of a multichannel signal array as audio. - -# # # # Example -# # # # ------- -# # # # >>> signal = np.random.randn(1, 2, 1000) # Random stereo signal -# # # # >>> listen(signal, 16000, channels=(0, 1)) - -# # # # Parameters -# # # # ---------- -# # # # signal_array : np.ndarray -# # # # Signal array of shape (batch_size, n_channels, sequence_length) -# # # # sampling_freq : int -# # # # Original sampling frequency of the signal -# # # # channels : Tuple[int, ...] -# # # # Tuple of channel indices to listen to -# # # # target_fs : int -# # # # Target sampling frequency for playback - -# # # # Returns -# # # # ------- -# # # # None -# # # # """ -# # # # if not isinstance(signal_array, np.ndarray): -# # # # signal_array = np.array(signal_array) - -# # # # if len(signal_array.shape) != 3: -# # # # raise ValueError(f"Expected 3D array, got shape {signal_array.shape}") - -# # # # if max(channels) >= signal_array.shape[1]: -# # # # raise ValueError(f"Channel index {max(channels)} out of range (max: {signal_array.shape[1]-1})") - -# # # # selected_channels = signal_array[:, channels, :].mean(axis=1) -# # # # audio_signal = selected_channels.mean(axis=0) - -# # # # if sampling_freq != target_fs: -# # # # num_samples = int(round(len(audio_signal) * target_fs / sampling_freq)) -# # # # audio_signal = resample(audio_signal, num_samples) - -# # # # sd.play(audio_signal, target_fs) -# # # # sd.wait() - -# # # # def print_device_info() -> None: -# # # # """ -# # # # Display information about the default audio output device. - -# # # # Example -# # # # ------- -# # # # >>> print_device_info() -# # # # Default Output Device Info: -# # # # -# # # # """ -# # # # try: -# # # # device_info = sd.query_devices(kind="output") -# # # # print(f"Default Output Device Info: \n{device_info}") -# # # # except sd.PortAudioError as err: -# # # # print(f"Error querying audio devices: {err}") - -# # # # def list_and_select_device() -> int: -# # # # """ -# # # # List available audio devices and prompt user to select one. - -# # # # Example -# # # # ------- -# # # # >>> device_id = list_and_select_device() -# # # # Available audio devices: -# # # # ... -# # # # Enter the ID of the device you want to use: - -# # # # Returns -# # # # ------- -# # # # int -# # # # Selected device ID -# # # # """ -# # # # try: -# # # # print("Available audio devices:") -# # # # devices = sd.query_devices() -# # # # print(devices) -# # # # device_id = int(input("Enter the ID of the device you want to use: ")) -# # # # if device_id not in range(len(devices)): -# # # # raise ValueError(f"Invalid device ID: {device_id}") -# # # # return device_id -# # # # except (ValueError, sd.PortAudioError) as err: -# # # # print(f"Error during device selection: {err}") -# # # # return 0 - -# # # # def frequency_shift(signal: np.ndarray, shift_factor: int = 200) -> np.ndarray: -# # # # """Direct frequency shifting""" -# # # # num_samples = int(len(signal) * shift_factor) -# # # # return resample(signal, num_samples) - -# # # # def am_Modulation(signal: np.ndarray, carrier_freq: float = 440, fs: int = 44100) -> np.ndarray: -# # # # """Amplitude Modulation""" -# # # # t = np.arange(len(signal)) / fs -# # # # carrier = np.sin(2 * np.pi * carrier_freq * t) -# # # # return (1 + signal) * carrier - -# # # # def fm_Modulation(signal: np.ndarray, carrier_freq: float = 440, sens: float = 0.5, fs: int = 44100) -> np.ndarray: -# # # # """Frequency Modulation""" -# # # # t = np.arange(len(signal)) / fs -# # # # phase = 2 * np.pi * carrier_freq * t + sens * np.cumsum(signal) / fs -# # # # return np.sin(phase) - -# # # # def sonify_eeg( -# # # # signal_array: np.ndarray, -# # # # sampling_freq: int, -# # # # method: str = 'shift', -# # # # channels: Tuple[int, ...] = (0, 1), -# # # # target_fs: int = 44100, -# # # # **kwargs -# # # # ) -> None: -# # # # """Main sonification function""" -# # # # selected_channels = signal_array[:, channels, :].mean(axis=1) -# # # # signal = selected_channels.mean(axis=0) - -# # # # # Normalize -# # # # signal = signal / np.max(np.abs(signal)) - -# # # # # Apply selected method -# # # # if method == 'shift': -# # # # audio = frequency_shift(signal, kwargs.get('shift_factor', 200)) -# # # # elif method == 'am': -# # # # audio = am_Modulation(signal, kwargs.get('carrier_freq', 440), target_fs) -# # # # elif method == 'fm': -# # # # audio = fm_Modulation(signal, kwargs.get('carrier_freq', 440), kwargs.get('sensitivity', 0.5), target_fs) -# # # # else: -# # # # raise ValueError(f"Unknown method: {method}") - -# # # # sd.play(audio, target_fs) -# # # # sd.wait() - -# # # # if __name__ == "__main__": -# # # # import scitex - -# # # # CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - -# # # # signal, time_points, sampling_freq = scitex.dsp.demo_sig("chirp") - -# # # # device_id = list_and_select_device() -# # # # sd.default.device = device_id - -# # # # listen(signal, sampling_freq) - -# # # # scitex.session.close(CONFIG) - -# # # # - -# # # - -# # - -# - -# EOF diff --git a/src/scitex/dsp/_misc.py b/src/scitex/dsp/_misc.py deleted file mode 100755 index 0437ffe41..000000000 --- a/src/scitex/dsp/_misc.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-11-05 01:03:32 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/_misc.py - -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-04-05 12:14:08 (ywatanabe)" - -from scitex.decorators import torch_fn - - -@torch_fn -def ensure_3d(x): - if x.ndim == 1: # assumes (seq_len,) - x = x.unsqueeze(0).unsqueeze(0) - elif x.ndim == 2: # assumes (batch_siize, seq_len) - x = x.unsqueeze(1) - return x - - -# @torch_fn -# def unbias(x, dim=-1, fn="mean"): -# if fn == "mean": -# return x - x.mean(dim=dim, keepdims=True) -# if fn == "min": -# return x - x.min(dim=dim, keepdims=True)[0] - - -# EOF diff --git a/src/scitex/dsp/_mne.py b/src/scitex/dsp/_mne.py deleted file mode 100755 index 36f3451b4..000000000 --- a/src/scitex/dsp/_mne.py +++ /dev/null @@ -1,43 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "2024-11-04 02:07:36 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/_mne.py - -try: - import mne - - MNE_AVAILABLE = True -except ImportError: - MNE_AVAILABLE = False - mne = None - -import pandas as pd - -from .params import EEG_MONTAGE_1020 - - -def get_eeg_pos(channel_names=EEG_MONTAGE_1020): - if not MNE_AVAILABLE: - raise ImportError( - "MNE-Python is not installed. Please install with: pip install mne" - ) - # Load the standard 10-20 montage - standard_montage = mne.channels.make_standard_montage("standard_1020") - standard_montage.ch_names = [ - ch_name.upper() for ch_name in standard_montage.ch_names - ] - - # Get the positions of the electrodes in the standard montage - positions = standard_montage.get_positions() - - df = pd.DataFrame(positions["ch_pos"])[channel_names] - - df.set_index(pd.Series(["x", "y", "z"])) - - return df - - -if __name__ == "__main__": - print(get_eeg_pos()) - - -# EOF diff --git a/src/scitex/dsp/_modulation_index.py b/src/scitex/dsp/_modulation_index.py deleted file mode 100755 index 6a075a9df..000000000 --- a/src/scitex/dsp/_modulation_index.py +++ /dev/null @@ -1,93 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "2024-11-04 02:09:55 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/_modulation_index.py - -try: - import torch - - TORCH_AVAILABLE = True -except ImportError: - TORCH_AVAILABLE = False - torch = None - -from scitex.decorators import signal_fn - -if TORCH_AVAILABLE: - from scitex.nn._ModulationIndex import ModulationIndex - - -@signal_fn -def modulation_index(pha, amp, n_bins=18, amp_prob=False): - """ - pha: (batch_size, n_chs, n_freqs_pha, n_segments, seq_len) - amp: (batch_size, n_chs, n_freqs_amp, n_segments, seq_len) - """ - if not TORCH_AVAILABLE: - raise ImportError( - "PyTorch is not installed. Please install with: pip install torch" - ) - return ModulationIndex(n_bins=n_bins, amp_prob=amp_prob)(pha, amp) - - -def _reshape(x, batch_size=2, n_chs=4): - return ( - torch.tensor(x) - .float() - .unsqueeze(0) - .unsqueeze(0) - .repeat(batch_size, n_chs, 1, 1, 1) - ) - - -if __name__ == "__main__": - import sys - - import matplotlib.pyplot as plt - - import scitex - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start( - sys, plt, fig_scale=3 - ) - - # Parameters - FS = 512 - T_SEC = 5 - - # Demo signal - xx, tt, fs = scitex.dsp.demo_sig(fs=FS, t_sec=T_SEC, sig_type="tensorpac") - # xx.shape: (8, 19, 20, 512) - - # Tensorpac - ( - pha, - amp, - freqs_pha, - freqs_amp, - pac_tp, - ) = scitex.dsp.utils.pac.calc_pac_with_tensorpac(xx, fs, t_sec=T_SEC) - - # GPU calculation with scitex.dsp.nn.ModulationIndex - pha, amp = _reshape(pha), _reshape(amp) - pac_scitex = scitex.dsp.modulation_index(pha, amp).cpu().numpy() - i_batch, i_ch = 0, 0 - pac_scitex = pac_scitex[i_batch, i_ch] - - # Plots - fig = scitex.dsp.utils.pac.plot_PAC_scitex_vs_tensorpac( - pac_scitex, pac_tp, freqs_pha, freqs_amp - ) - fig.suptitle("MI (modulation index) calculation") - scitex.io.save(fig, "modulation_index.png") - - # Close - scitex.session.close(CONFIG) - -# EOF - -""" -/home/ywatanabe/proj/entrance/scitex/dsp/_modulation_index.py -""" - -# EOF diff --git a/src/scitex/dsp/_pac.py b/src/scitex/dsp/_pac.py deleted file mode 100755 index d4e31dcf8..000000000 --- a/src/scitex/dsp/_pac.py +++ /dev/null @@ -1,337 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "2024-11-26 22:24:40 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/_pac.py - -THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/dsp/_pac.py" - -import sys - -import matplotlib.pyplot as plt -import numpy as np - -try: - import torch - - TORCH_AVAILABLE = True -except ImportError: - TORCH_AVAILABLE = False - torch = None - - -from scitex.decorators import signal_fn - -if TORCH_AVAILABLE: - from scitex.nn._PAC import PAC - - -def _check_torch(): - if not TORCH_AVAILABLE: - raise ImportError( - "PyTorch is not installed. Please install with: pip install torch" - ) - - -""" -scitex.dsp.pac function -""" - - -# @batch_fn -@signal_fn -def pac( - x, - fs, - pha_start_hz=2, - pha_end_hz=20, - pha_n_bands=100, - amp_start_hz=60, - amp_end_hz=160, - amp_n_bands=100, - device="cuda", - batch_size=1, - batch_size_ch=-1, - fp16=False, - trainable=False, - n_perm=None, - amp_prob=False, -): - """ - Compute the phase-amplitude coupling (PAC) for signals. This function automatically handles inputs as - PyTorch tensors, NumPy arrays, or pandas DataFrames. - - Arguments: - - x (torch.Tensor | np.ndarray | pd.DataFrame): Input signal. Shape can be either (batch_size, n_chs, seq_len) or - - fs (float): Sampling frequency of the input signal. - - pha_start_hz (float, optional): Start frequency for phase bands. Default is 2 Hz. - - pha_end_hz (float, optional): End frequency for phase bands. Default is 20 Hz. - - pha_n_bands (int, optional): Number of phase bands. Default is 100. - - amp_start_hz (float, optional): Start frequency for amplitude bands. Default is 60 Hz. - - amp_end_hz (float, optional): End frequency for amplitude bands. Default is 160 Hz. - - amp_n_bands (int, optional): Number of amplitude bands. Default is 100. - - Returns: - - torch.Tensor: PAC values. Shape: (batch_size, n_chs, pha_n_bands, amp_n_bands) - - numpy.ndarray: Phase bands used for the computation. - - numpy.ndarray: Amplitude bands used for the computation. - - Example: - FS = 512 - T_SEC = 4 - xx, tt, fs = scitex.dsp.demo_sig( - batch_size=1, n_chs=1, fs=FS, t_sec=T_SEC, sig_type="tensorpac" - ) - pac, pha_mids_hz, amp_mids_hz = scitex.dsp.pac(xx, fs) - """ - _check_torch() - - def process_ch_batching(m, x, batch_size_ch, device): - n_chs = x.shape[1] - n_batches = (n_chs + batch_size_ch - 1) // batch_size_ch - - agg = [] - for ii in range(n_batches): - start, end = batch_size_ch * ii, min(batch_size_ch * (ii + 1), n_chs) - _pac = m(x[:, start:end, :].to(device)).detach().cpu() - agg.append(_pac) - - # return np.concatenate(agg, axis=1) - return torch.cat(agg, dim=1) - - m = PAC( - x.shape[-1], - fs, - pha_start_hz=pha_start_hz, - pha_end_hz=pha_end_hz, - pha_n_bands=pha_n_bands, - amp_start_hz=amp_start_hz, - amp_end_hz=amp_end_hz, - amp_n_bands=amp_n_bands, - fp16=fp16, - trainable=trainable, - n_perm=n_perm, - amp_prob=amp_prob, - ).to(device) - - if batch_size_ch == -1: - return m(x.to(device)), m.PHA_MIDS_HZ, m.AMP_MIDS_HZ - else: - return ( - process_ch_batching(m, x, batch_size_ch, device), - m.PHA_MIDS_HZ, - m.AMP_MIDS_HZ, - ) - - -if __name__ == "__main__": - import matplotlib.pyplot as plt - - import scitex - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - - pac, freqs_pha, freqs_amp = scitex.dsp.pac( - np.random.rand(1, 16, 24000), - 400, - batch_size=1, - batch_size_ch=8, - fp16=True, - n_perm=16, - ) - -# # Parameters -# FS = 512 -# T_SEC = 4 -# # IS_TRAINABLE = False -# # FP16 = True - -# for IS_TRAINABLE in [True, False]: -# for FP16 in [True, False]: - -# # Demo signal -# xx, tt, fs = scitex.dsp.demo_sig( -# batch_size=1, n_chs=1, fs=FS, t_sec=T_SEC, sig_type="pac" -# ) - - -# # scitex.str.print_debug() -# # xx = np.random.rand(1,16,24000) -# # fs = 400 - -# # scitex calculation -# pac_scitex, pha_mids_scitex, amp_mids_scitex = scitex.dsp.pac( -# xx, -# fs, -# pha_n_bands=50, -# amp_n_bands=30, -# trainable=IS_TRAINABLE, -# fp16=FP16, -# ) -# i_batch, i_ch = 0, 0 -# pac_scitex = pac_scitex[i_batch, i_ch] - -# printc(type(pac_scitex)) - -# # Tensorpac calculation -# ( -# _, -# _, -# _pha_mids_tp, -# _amp_mids_tp, -# pac_tp, -# ) = scitex.dsp.utils.pac.calc_pac_with_tensorpac(xx, fs, T_SEC) - -# # Validates the consitency in frequency definitions -# assert np.allclose( -# pha_mids_scitex, _pha_mids_tp -# ) -# assert np.allclose( -# amp_mids_scitex, _amp_mids_tp -# ) - -# scitex.io.save( -# (pac_scitex, pac_tp, pha_mids_scitex, amp_mids_scitex), -# "./data/cache.npz", -# ) - -# # ################################################################################ -# # # cache -# # pac_scitex, pac_tp, pha_mids_scitex, amp_mids_scitex = scitex.io.load( -# # "./data/cache.npz" -# # ) -# # ################################################################################ - -# # Plots -# fig = scitex.dsp.utils.pac.plot_PAC_scitex_vs_tensorpac( -# pac_scitex, pac_tp, pha_mids_scitex, amp_mids_scitex -# ) -# fig.suptitle( -# "Phase-Amplitude Coupling calculation\n\n(Bandpass Filtering -> Hilbert Transformation-> Modulation Index)" -# ) -# plt.show() - -# scitex.gen.reload(scitex.dsp) - -# # Saves the figure -# trainable_str = "trainable" if IS_TRAINABLE else "static" -# fp_str = "fp16" if FP16 else "fp32" -# scitex.io.save( -# fig, f"pac_with_{trainable_str}_bandpass_{fp_str}.png" -# ) - - -# def run_method_tests(): -# import scitex - -# # Test parameters -# FS = 512 -# T_SEC = 4 - -# class PACProcessor: -# @batch_torch_fn -# def process_pac(self, x, fs, **kwargs): -# return pac(x, fs, **kwargs) - -# @signal_fn -# def process_signal(self, x): -# return x * 2 - -# def run_method_basic_tests(): -# processor = PACProcessor() - -# # Generate test signal -# xx, tt, fs = scitex.dsp.demo_sig( -# batch_size=1, n_chs=1, fs=FS, t_sec=T_SEC, sig_type="pac" -# ) - -# try: -# # Test method with batch processing -# result_batch, pha_mids, amp_mids = processor.process_pac( -# xx, fs, pha_n_bands=50, amp_n_bands=30, batch_size=1 -# ) -# assert torch.is_tensor(result_batch) - -# # Test basic torch method -# result_torch = processor.process_signal(xx) -# assert torch.is_tensor(result_torch) - -# scitex.str.printc("Passed: Basic method tests", "yellow") -# except Exception as err: -# scitex.str.printc(f"Failed: Basic method tests - {str(err)}", "red") - -# def run_method_cuda_tests(): -# if not torch.cuda.is_available(): -# scitex.str.printc( -# "CUDA method tests skipped: No GPU available", "yellow" -# ) -# return - -# processor = PACProcessor() -# xx, tt, fs = scitex.dsp.demo_sig( -# batch_size=1, n_chs=1, fs=FS, t_sec=T_SEC, sig_type="pac" -# ) - -# try: -# # Test with CUDA -# result_cuda, _, _ = processor.process_pac(xx, fs, device="cuda") -# assert result_cuda.device.type == "cuda" - -# result_torch = processor.process_signal(xx, device="cuda") -# assert result_torch.device.type == "cuda" - -# scitex.str.printc("Passed: CUDA method tests", "yellow") -# except Exception as err: -# scitex.str.printc(f"Failed: CUDA method tests - {str(err)}", "red") - -# def run_method_batch_size_tests(): -# processor = PACProcessor() -# batch_sizes = [1, 2, 4] - -# for batch_size in batch_sizes: -# try: -# xx, tt, fs = scitex.dsp.demo_sig( -# batch_size=batch_size, -# n_chs=1, -# fs=FS, -# t_sec=T_SEC, -# sig_type="pac", -# ) - -# result, _, _ = processor.process_pac( -# xx, fs, batch_size=batch_size -# ) -# assert result.shape[0] == batch_size - -# scitex.str.printc( -# f"Passed: Method batch size test with size={batch_size}", -# "yellow", -# ) -# except Exception as err: -# scitex.str.printc( -# f"Failed: Method batch size test with size={batch_size} - {str(err)}", -# "red", -# ) - -# # Execute method test suites -# test_suites = [ -# ("Method Basic Tests", run_method_basic_tests), -# ("Method CUDA Tests", run_method_cuda_tests), -# ("Method Batch Size Tests", run_method_batch_size_tests), -# ] - -# for test_name, test_func in test_suites: -# test_func() - - -# if __name__ == "__main__": -# run_method_tests() - -# # EOF - -# """ -# python -m scitex.dsp._pac -# """ - -# - -# EOF diff --git a/src/scitex/dsp/_psd.py b/src/scitex/dsp/_psd.py deleted file mode 100755 index 0e92f33e0..000000000 --- a/src/scitex/dsp/_psd.py +++ /dev/null @@ -1,114 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "2024-11-04 02:11:25 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/_psd.py - -"""This script does XYZ.""" - -try: - import torch - - TORCH_AVAILABLE = True -except ImportError: - TORCH_AVAILABLE = False - torch = None - -from scitex.decorators import signal_fn - -if TORCH_AVAILABLE: - from scitex.nn._PSD import PSD - - -@signal_fn -def psd( - x, - fs, - prob=False, - dim=-1, -): - """ - import matplotlib.pyplot as plt - - x, t, fs = scitex.dsp.demo_sig() # (batch_size, n_chs, seq_len) - pp, ff = psd(x, fs) - - # Plots - plt, CC = scitex.plt.configure_mpl(plt) - fig, ax = scitex.plt.subplots() - ax.plot(fs, pp[0, 0]) - ax.xlabel("Frequency [Hz]") - ax.ylabel("log(Power [uV^2 / Hz]) [a.u.]") - plt.show() - """ - if not TORCH_AVAILABLE: - raise ImportError( - "PyTorch is not installed. Please install with: pip install torch" - ) - psd, freqs = PSD(fs, prob=prob, dim=dim)(x) - return psd, freqs - - -def band_powers(self, psd): - """ - Calculate the average power for specified frequency bands. - """ - assert len(self.low_freqs) == len(self.high_freqs) - - out = [] - for ll, hh in zip(self.low_freqs, self.high_freqs): - band_indices = torch.where((freqs >= ll) & (freqs <= hh))[0].to(psd.device) - band_power = psd[..., band_indices].sum(dim=self.dim) - bandwidth = hh - ll - avg_band_power = band_power / bandwidth - out.append(avg_band_power) - out = torch.stack(out, dim=-1) - return out - - # Average Power in Each Frequency Band - avg_band_powers = self.calc_band_avg_power(psd, freqs) - return (avg_band_powers,) - - -if __name__ == "__main__": - import sys - - import matplotlib.pyplot as plt - - import scitex - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - - # Parameters - SIG_TYPE = "chirp" - - # Demo signal - xx, tt, fs = scitex.dsp.demo_sig(SIG_TYPE) # (8, 19, 384) - - # PSD calculation - pp, ff = psd(xx, fs, prob=True) - - # Plots - fig, axes = scitex.plt.subplots(nrows=2) - - axes[0].plot(tt, xx[0, 0], label=SIG_TYPE) - axes[1].set_title("Signal") - axes[0].set_xlabel("Time [s]") - axes[0].set_ylabel("Amplitude [?V]") - - axes[1].plot(ff, pp[0, 0]) - axes[1].set_title("PSD (power spectrum density)") - axes[1].set_xlabel("Frequency [Hz]") - axes[1].set_ylabel("Log(Power [?V^2 / Hz]) [a.u.]") - - scitex.io.save(fig, "psd.png") - - # Close - scitex.session.close(CONFIG) - -# EOF - -""" -/home/ywatanabe/proj/entrance/scitex/dsp/_psd.py -""" - -# EOF diff --git a/src/scitex/dsp/_resample.py b/src/scitex/dsp/_resample.py deleted file mode 100755 index 14f5540e7..000000000 --- a/src/scitex/dsp/_resample.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "2024-04-13 02:35:11 (ywatanabe)" - -try: - import torch - - TORCH_AVAILABLE = True -except ImportError: - TORCH_AVAILABLE = False - torch = None - -try: - import torchaudio.transforms as T - - TORCHAUDIO_AVAILABLE = True -except ImportError: - TORCHAUDIO_AVAILABLE = False - T = None - -import scitex -from scitex.decorators import signal_fn - - -@signal_fn -def resample(x, src_fs, tgt_fs, t=None): - if not TORCH_AVAILABLE: - raise ImportError( - "PyTorch is not installed. Please install with: pip install torch" - ) - if not TORCHAUDIO_AVAILABLE: - raise ImportError( - "torchaudio is not installed. Please install with: pip install torchaudio" - ) - xr = T.Resample(src_fs, tgt_fs, dtype=x.dtype).to(x.device)(x) - if t is None: - return xr - if t is not None: - tr = torch.linspace(t[0], t[-1], xr.shape[-1]) - return xr, tr - - -if __name__ == "__main__": - import sys - - import matplotlib.pyplot as plt - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - - # Parameters - T_SEC = 1 - SIG_TYPE = "chirp" - SRC_FS = 128 - TGT_FS_UP = 256 - TGT_FS_DOWN = 64 - FREQS_HZ = [10, 30, 100, 300] - - # Demo Signal - xx, tt, fs = scitex.dsp.demo_sig( - t_sec=T_SEC, fs=SRC_FS, freqs_hz=FREQS_HZ, sig_type=SIG_TYPE - ) - - # Resampling - xd, td = scitex.dsp.resample(xx, fs, TGT_FS_DOWN, t=tt) - xu, tu = scitex.dsp.resample(xx, fs, TGT_FS_UP, t=tt) - - # Plots - i_batch, i_ch = 0, 0 - fig, axes = plt.subplots(nrows=3, sharex=True, sharey=True) - axes[0].plot(tt, xx[i_batch, i_ch], label=f"Original ({SRC_FS} Hz)") - axes[1].plot(td, xd[i_batch, i_ch], label=f"Down-sampled ({TGT_FS_DOWN} Hz)") - axes[2].plot(tu, xu[i_batch, i_ch], label=f"Up-sampled ({TGT_FS_UP} Hz)") - for ax in axes: - ax.legend(loc="upper left") - - axes[-1].set_xlabel("Time [s]") - fig.supylabel("Amplitude [?V]") - fig.suptitle("Resampling") - scitex.io.save(fig, "traces.png") - # plt.show() - -# EOF - -""" -/home/ywatanabe/proj/entrance/scitex/dsp/_resample.py -""" diff --git a/src/scitex/dsp/_skills/SKILL.md b/src/scitex/dsp/_skills/SKILL.md deleted file mode 100644 index ad3348a91..000000000 --- a/src/scitex/dsp/_skills/SKILL.md +++ /dev/null @@ -1,56 +0,0 @@ ---- -name: stx.dsp -description: Digital signal processing for neuroscience — filtering, spectral analysis, phase-amplitude coupling, ripple detection, wavelets, and resampling. ---- - -# stx.dsp — Skill Index - -Digital signal processing (DSP) utilities for neuroscience and time-series analysis. All major functions accept NumPy arrays, PyTorch tensors, or pandas DataFrames via the `@signal_fn` decorator and return the same type as input. - -## Sub-skills - -| File | Feature Area | -|------|-------------| -| [filtering.md](filtering.md) | Bandpass, bandstop, lowpass, highpass, Gaussian filters | -| [spectral.md](spectral.md) | Power spectral density and band power extraction | -| [hilbert.md](hilbert.md) | Analytic signal: amplitude envelope and instantaneous phase | -| [pac.md](pac.md) | Phase-amplitude coupling (`pac`, `modulation_index`) | -| [ripple-detection.md](ripple-detection.md) | Hippocampal sharp-wave ripple detection | -| [wavelet.md](wavelet.md) | Continuous wavelet transform | -| [resampling.md](resampling.md) | Anti-aliased up/down resampling | -| [segmentation.md](segmentation.md) | Sliding-window segmentation and sktime conversion | -| [noise.md](noise.md) | Add Gaussian, white, pink, or brown noise | -| [normalization.md](normalization.md) | Z-score and min-max normalization | -| [referencing.md](referencing.md) | Common-average, random, and target re-referencing | -| [demo-signal.md](demo-signal.md) | Synthetic signal generation for testing | -| [params.md](params.md) | Built-in EEG frequency bands and electrode montages | -| [utils.md](utils.md) | Helpers: zero-padding, FIR filter design, differentiable bandpass filters | - -## Quick Start - -```python -import scitex as stx -import numpy as np - -# Generate demo signal: shape (batch=8, chs=19, time=2048) -xx, tt, fs = stx.dsp.demo_sig(sig_type="chirp", fs=512, t_sec=4) - -# Bandpass filter 8-30 Hz -xx_bp = stx.dsp.filt.bandpass(xx, fs, np.array([[8, 30]])) - -# Power spectral density -psd_vals, freqs = stx.dsp.psd(xx, fs) - -# Wavelet transform -> phase, amplitude, frequency axis -pha, amp, freqs_w = stx.dsp.wavelet(xx, fs) -``` - -## Optional Dependencies - -| Feature | Requires | Install | -|---------|----------|---------| -| Filters, PSD, PAC, wavelet, resample | `torch`, `torchaudio` | `pip install torch torchaudio` | -| Audio device listing | `sounddevice`, PortAudio | `pip install sounddevice` + `apt install portaudio19-dev` | -| EEG electrode positions | `mne` | `pip install mne` | -| Ripple demo signal | `ripple_detection` | `pip install ripple_detection` | -| Tensorpac demo / PAC comparison | `tensorpac` | `pip install tensorpac` | diff --git a/src/scitex/dsp/_skills/demo-signal.md b/src/scitex/dsp/_skills/demo-signal.md deleted file mode 100644 index 3039714f7..000000000 --- a/src/scitex/dsp/_skills/demo-signal.md +++ /dev/null @@ -1,105 +0,0 @@ ---- -description: Generate synthetic signals for testing — periodic, chirp, ripple, Gaussian, PAC, MEG. ---- - -# stx.dsp.demo_sig — Demo Signal Generation - -Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_demo_sig.py` - -## Signature - -```python -xx, tt, fs = stx.dsp.demo_sig( - sig_type="periodic", - batch_size=8, - n_chs=19, - n_segments=20, - t_sec=4, - fs=512, - freqs_hz=None, - verbose=False, -) -``` - -### Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `sig_type` | str | `"periodic"` | Signal type (see table below) | -| `batch_size` | int | `8` | Number of batches | -| `n_chs` | int | `19` | Number of channels | -| `n_segments` | int | `20` | Segments (only for `"tensorpac"` and `"pac"`) | -| `t_sec` | float | `4` | Duration in seconds | -| `fs` | int | `512` | Sampling frequency in Hz | -| `freqs_hz` | list or None | `None` | Frequencies for periodic signal; random if `None` | -| `verbose` | bool | `False` | Print frequency information | - -### Returns - -- `xx`: signal array, shape depends on `sig_type` (see table) -- `tt`: time vector, shape `(t_sec * fs,)` -- `fs`: sampling frequency (same as input) - -## Signal types - -| `sig_type` | Output shape | Requires | Description | -|------------|-------------|----------|-------------| -| `"uniform"` | `(batch, chs, time)` | — | Uniform random in `[-0.5, 0.5]` | -| `"gauss"` | `(batch, chs, time)` | — | Standard Gaussian noise | -| `"periodic"` | `(batch, chs, time)` | — | Sum of sine waves at `freqs_hz` | -| `"chirp"` | `(batch, chs, time)` | — | Linear frequency sweep with AM envelope | -| `"ripple"` | `(batch, chs, time)` | `ripple_detection` | Simulated hippocampal LFP with ripples | -| `"meg"` | `(batch, chs, time)` | `mne` | Real MEG segment from MNE sample dataset | -| `"tensorpac"` | `(batch, chs, segs, time)` | `tensorpac` | PAC signal via `pac_signals_wavelet` | -| `"pac"` | `(batch, chs, segs, time)` | — | Synthetic PAC: theta phase modulating gamma amplitude | - -## Examples - -```python -import scitex as stx - -# Default: 8 batches, 19 channels, 4s at 512 Hz -xx, tt, fs = stx.dsp.demo_sig() -print(xx.shape) # (8, 19, 2048) - -# Chirp (frequency sweep) -xx, tt, fs = stx.dsp.demo_sig(sig_type="chirp", fs=512, t_sec=2) - -# Periodic with specific frequencies -xx, tt, fs = stx.dsp.demo_sig( - sig_type="periodic", - freqs_hz=[10, 30, 100, 300], - fs=1024, - t_sec=1, - batch_size=4, - n_chs=2, -) - -# PAC signal with segment dimension (for modulation_index) -xx, tt, fs = stx.dsp.demo_sig(sig_type="pac", n_segments=20, fs=512, t_sec=4) -print(xx.shape) # (8, 19, 20, 2048) - -# PAC signal using tensorpac wavelet method -xx, tt, fs = stx.dsp.demo_sig(sig_type="tensorpac", n_segments=20) - -# Ripple simulation -xx, tt, fs = stx.dsp.demo_sig(sig_type="ripple", fs=1000, t_sec=10) -``` - -## Internal signal constructors - -These private functions can be called directly for 1D signals: - -```python -from scitex.dsp._demo_sig import ( - _demo_sig_periodic_1d, - _demo_sig_chirp_1d, - _demo_sig_ripple_1d, -) - -# Single channel periodic -sig_1d = _demo_sig_periodic_1d(t_sec=2, fs=512, freqs_hz=[10, 40]) - -# Single channel chirp -sig_chirp = _demo_sig_chirp_1d(t_sec=2, fs=512, low_hz=5, high_hz=200) -``` diff --git a/src/scitex/dsp/_skills/filtering.md b/src/scitex/dsp/_skills/filtering.md deleted file mode 100644 index db5f3c711..000000000 --- a/src/scitex/dsp/_skills/filtering.md +++ /dev/null @@ -1,110 +0,0 @@ ---- -description: Bandpass, bandstop, lowpass, highpass, and Gaussian filters for multi-channel signals. ---- - -# stx.dsp.filt — Filtering - -Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/filt.py` - -All filter functions are decorated with `@signal_fn`, which means they accept NumPy arrays, PyTorch tensors, or pandas DataFrames and return the same type. Input shape must be `(batch_size, n_chs, seq_len)` or compatible broadcastable form. - -Filters are implemented as PyTorch neural network modules from `scitex.nn._Filters`. They require `torch`. - -## Function Signatures - -```python -stx.dsp.filt.bandpass(x, fs, bands, t=None) -stx.dsp.filt.bandstop(x, fs, bands, t=None) -stx.dsp.filt.lowpass(x, fs, cutoffs_hz, t=None) -stx.dsp.filt.highpass(x, fs, cutoffs_hz, t=None) -stx.dsp.filt.gauss(x, sigma, t=None) -``` - -### Parameters - -| Parameter | Type | Description | -|-----------|------|-------------| -| `x` | ndarray / Tensor | Input signal, shape `(batch, chs, time)` | -| `fs` | float | Sampling frequency in Hz | -| `bands` | array `(n_bands, 2)` | `[[low_hz, high_hz], ...]` for bandpass/bandstop | -| `cutoffs_hz` | array `(n_bands,)` | Cutoff frequencies for lowpass/highpass | -| `sigma` | float | Gaussian kernel width in samples (standard deviations) | -| `t` | ndarray or None | Optional time vector; if given, also returned | - -### Return values - -- Without `t`: filtered signal, same shape and type as input -- With `t`: `(filtered_signal, time_vector)` — time vector is unchanged - -## Examples - -```python -import scitex as stx -import numpy as np - -xx, tt, fs = stx.dsp.demo_sig(sig_type="periodic", fs=1024, t_sec=1) -# xx.shape: (8, 19, 1024) - -# Single band: [[low_hz, high_hz]] -BANDS = np.array([[80, 310]]) - -# Bandpass 80-310 Hz -x_bp = stx.dsp.filt.bandpass(xx, fs, BANDS) - -# Bandstop 80-310 Hz (notch) -x_bs = stx.dsp.filt.bandstop(xx, fs, BANDS) - -# Lowpass at 80 Hz -x_lp = stx.dsp.filt.lowpass(xx, fs, BANDS[:, 0]) - -# Highpass at 310 Hz -x_hp = stx.dsp.filt.highpass(xx, fs, BANDS[:, 1]) - -# Gaussian smoothing (sigma=3 samples) -x_g = stx.dsp.filt.gauss(xx, sigma=3) - -# With time vector returned -x_bp, t_bp = stx.dsp.filt.bandpass(xx, fs, BANDS, t=tt) -``` - -## Multi-band filtering - -`bandpass` and `bandstop` accept multiple bands at once. The output gains an extra dimension for each band: - -```python -BANDS = np.array([[4, 8], [8, 13], [13, 30]]) # theta, alpha, beta -x_multi = stx.dsp.filt.bandpass(xx, fs, BANDS) -# x_multi.shape: (8, 19, 3, 1024) — extra dim for n_bands -``` - -## EEG use case - -```python -# Ripple band detection preprocessing -ripple_bands = np.array([[80, 140]]) -x_ripple = stx.dsp.filt.bandpass(xx, fs, ripple_bands) - -# Theta band for PAC phase -theta_bands = np.array([[4, 8]]) -x_theta = stx.dsp.filt.bandpass(xx, fs, theta_bands) -``` - -## FIR filter design utility - -`stx.dsp.utils.filter.design_filter` exposes the underlying FIR design for inspection: - -```python -from scitex.dsp.utils.filter import design_filter, plot_filter_responses - -xx, tt, fs = stx.dsp.demo_sig() -seq_len = xx.shape[-1] - -# Returns filter coefficients (numpy array) -bp_filter = design_filter(seq_len, fs, low_hz=30, high_hz=70) -lp_filter = design_filter(seq_len, fs, low_hz=30) -hp_filter = design_filter(seq_len, fs, high_hz=70) -bs_filter = design_filter(seq_len, fs, low_hz=30, high_hz=70, is_bandstop=True) - -# Plot impulse + frequency response -fig = plot_filter_responses(bp_filter, fs, title="Bandpass 30-70 Hz") -``` diff --git a/src/scitex/dsp/_skills/hilbert.md b/src/scitex/dsp/_skills/hilbert.md deleted file mode 100644 index 8004c2728..000000000 --- a/src/scitex/dsp/_skills/hilbert.md +++ /dev/null @@ -1,72 +0,0 @@ ---- -description: Hilbert transform returning instantaneous phase and amplitude envelope. ---- - -# stx.dsp.hilbert — Analytic Signal - -Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_hilbert.py` - -## Signature - -```python -phase, amplitude = stx.dsp.hilbert(x, dim=-1) -``` - -Computes the analytic signal via the Hilbert transform using `scitex.nn._Hilbert`. Both outputs have the same shape as `x`. - -### Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `x` | ndarray / Tensor | required | Input signal, shape `(batch, chs, time)` | -| `dim` | int | `-1` | Dimension along which to apply the transform | - -### Returns - -- `phase`: instantaneous phase in radians, same shape as `x` -- `amplitude`: amplitude envelope (always non-negative), same shape as `x` - -Decorated with `@signal_fn`: accepts NumPy arrays, PyTorch tensors, or DataFrames; returns the same type. - -## Examples - -```python -import scitex as stx - -xx, tt, fs = stx.dsp.demo_sig(sig_type="chirp", t_sec=1.0, fs=400) -# xx.shape: (8, 19, 400) - -phase, amplitude = stx.dsp.hilbert(xx) -# phase.shape: (8, 19, 400) — values in [-pi, pi] -# amplitude.shape: (8, 19, 400) — non-negative envelope - -# Plot signal and envelope for first batch/channel -import matplotlib.pyplot as plt -fig, axes = plt.subplots(2, 1, sharex=True) -axes[0].plot(tt, xx[0, 0], label="signal") -axes[0].plot(tt, amplitude[0, 0], label="envelope") -axes[0].legend() -axes[1].plot(tt, phase[0, 0], label="phase [rad]") -axes[1].legend() -``` - -## PAC preprocessing pipeline - -`hilbert` is used internally by `pac` and `detect_ripples`, but you can also call it directly for custom phase-amplitude analyses: - -```python -import scitex as stx -import numpy as np - -xx, tt, fs = stx.dsp.demo_sig(fs=512, t_sec=4) - -# Extract theta phase -theta = stx.dsp.filt.bandpass(xx, fs, np.array([[4, 8]])) -theta_phase, _ = stx.dsp.hilbert(theta) - -# Extract gamma amplitude -gamma = stx.dsp.filt.bandpass(xx, fs, np.array([[32, 80]])) -_, gamma_amp = stx.dsp.hilbert(gamma) - -# theta_phase and gamma_amp are ready for coupling analysis -``` diff --git a/src/scitex/dsp/_skills/noise.md b/src/scitex/dsp/_skills/noise.md deleted file mode 100644 index 9eacce2bd..000000000 --- a/src/scitex/dsp/_skills/noise.md +++ /dev/null @@ -1,72 +0,0 @@ ---- -description: Add Gaussian, white, pink, or brown noise to signals. ---- - -# stx.dsp.add_noise — Noise Addition - -Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/add_noise.py` - -All functions are decorated with `@signal_fn`, accept NumPy arrays or PyTorch tensors, and return the same type. Requires `torch`. - -## Functions - -```python -noisy = stx.dsp.add_noise.gauss(x, amp=1.0) -noisy = stx.dsp.add_noise.white(x, amp=1.0) -noisy = stx.dsp.add_noise.pink(x, amp=1.0, dim=-1) -noisy = stx.dsp.add_noise.brown(x, amp=1.0, dim=-1) -``` - -### Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `x` | ndarray / Tensor | required | Input signal | -| `amp` | float | `1.0` | Noise amplitude (scale factor) | -| `dim` | int | `-1` | Dimension along which to generate correlated noise (pink, brown only) | - -### Returns - -Signal with noise added, same shape and type as input. - -## Noise types - -| Function | Type | Spectrum | -|----------|------|---------| -| `gauss` | Gaussian | White (flat), samples from `N(0, amp)` | -| `white` | Uniform | White (flat), samples from `U(-amp, amp)` | -| `pink` | 1/f noise | Pink spectrum: power `~ 1/f` | -| `brown` | Brownian | Red spectrum: cumulative sum of uniform, then min-max normalized | - -## Examples - -```python -import scitex as stx - -xx, tt, fs = stx.dsp.demo_sig(fs=128, t_sec=1) - -# Add different noise types -xx_gauss = stx.dsp.add_noise.gauss(xx, amp=0.5) -xx_white = stx.dsp.add_noise.white(xx, amp=0.5) -xx_pink = stx.dsp.add_noise.pink(xx, amp=0.5) -xx_brown = stx.dsp.add_noise.brown(xx, amp=0.5) - -# Inspect noise alone -noise = stx.dsp.add_noise.pink(xx, amp=1.0) - xx -``` - -## Data augmentation use case - -```python -import scitex as stx -import numpy as np - -xx, tt, fs = stx.dsp.demo_sig(batch_size=8, fs=256, t_sec=2) - -# Augment by mixing noise types with different amplitudes -xx_aug_1 = stx.dsp.add_noise.gauss(xx, amp=0.1) # mild Gaussian -xx_aug_2 = stx.dsp.add_noise.pink(xx, amp=0.3) # realistic 1/f noise - -augmented = np.concatenate([xx, xx_aug_1, xx_aug_2], axis=0) -# augmented.shape: (24, 19, 512) -``` diff --git a/src/scitex/dsp/_skills/normalization.md b/src/scitex/dsp/_skills/normalization.md deleted file mode 100644 index 332593749..000000000 --- a/src/scitex/dsp/_skills/normalization.md +++ /dev/null @@ -1,78 +0,0 @@ ---- -description: Z-score and min-max normalization for multi-channel signals. ---- - -# stx.dsp.norm — Normalization - -Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/norm.py` - -Both functions are decorated with `@signal_fn`, accept NumPy arrays or PyTorch tensors, and return the same type. Requires `torch`. - -## Functions - -```python -x_z = stx.dsp.norm.z(x, dim=-1) -x_mm = stx.dsp.norm.minmax(x, amp=1.0, dim=-1, fn="mean") -``` - -### stx.dsp.norm.z - -Z-score normalization: `(x - mean) / std` along `dim`. - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `x` | ndarray / Tensor | required | Input signal | -| `dim` | int | `-1` | Dimension to normalize along | - -Returns tensor with mean 0 and std 1 along `dim`. - -### stx.dsp.norm.minmax - -Min-max normalization scaled to `[-amp, amp]` using the max absolute value. - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `x` | ndarray / Tensor | required | Input signal | -| `amp` | float | `1.0` | Output amplitude scale (result bounded by `[-amp, amp]`) | -| `dim` | int | `-1` | Dimension to normalize along | -| `fn` | str | `"mean"` | Unused (present for API compatibility) | - -Implementation: divides by `max(|max|, |min|)` so the output fits in `[-amp, amp]`. - -## Examples - -```python -import scitex as stx - -xx, tt, fs = stx.dsp.demo_sig(fs=256, t_sec=2) -# xx.shape: (8, 19, 512) - -# Z-score normalize along time dimension -xx_z = stx.dsp.norm.z(xx, dim=-1) -# Each channel has mean ~0 and std ~1 - -# Min-max normalize to [-1, 1] range -xx_mm = stx.dsp.norm.minmax(xx, amp=1.0, dim=-1) - -# Normalize along channel dimension (across channels per time point) -xx_ch = stx.dsp.norm.z(xx, dim=1) - -# Normalize each batch independently -xx_b = stx.dsp.norm.z(xx, dim=-1) -``` - -## Common use cases - -```python -import scitex as stx - -xx, tt, fs = stx.dsp.demo_sig() - -# Pre-normalize before bandpass filtering -xx_norm = stx.dsp.norm.z(xx) -xx_filt = stx.dsp.filt.bandpass(xx_norm, fs, [[4, 8]]) - -# Normalize after wavelet transform for visualization -pha, amp, freqs = stx.dsp.wavelet(xx, fs) -amp_norm = stx.dsp.norm.minmax(amp, amp=1.0) -``` diff --git a/src/scitex/dsp/_skills/pac.md b/src/scitex/dsp/_skills/pac.md deleted file mode 100644 index dc7f51917..000000000 --- a/src/scitex/dsp/_skills/pac.md +++ /dev/null @@ -1,154 +0,0 @@ ---- -description: Phase-amplitude coupling (PAC) via GPU-accelerated bandpass filtering and modulation index. ---- - -# stx.dsp — Phase-Amplitude Coupling - -Sources: -- `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_pac.py` -- `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_modulation_index.py` - -Both functions require `torch`. `pac` additionally benefits from CUDA. - -## stx.dsp.pac - -High-level end-to-end PAC from raw signal. - -```python -pac_vals, pha_mids_hz, amp_mids_hz = stx.dsp.pac( - x, - fs, - pha_start_hz=2, - pha_end_hz=20, - pha_n_bands=100, - amp_start_hz=60, - amp_end_hz=160, - amp_n_bands=100, - device="cuda", - batch_size=1, - batch_size_ch=-1, - fp16=False, - trainable=False, - n_perm=None, - amp_prob=False, -) -``` - -### Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `x` | ndarray / Tensor | required | Signal, shape `(batch, chs, time)` or `(batch, chs, segments, time)` | -| `fs` | float | required | Sampling frequency in Hz | -| `pha_start_hz` | float | `2` | Start of phase frequency range | -| `pha_end_hz` | float | `20` | End of phase frequency range | -| `pha_n_bands` | int | `100` | Number of phase bands to compute | -| `amp_start_hz` | float | `60` | Start of amplitude frequency range | -| `amp_end_hz` | float | `160` | End of amplitude frequency range | -| `amp_n_bands` | int | `100` | Number of amplitude bands to compute | -| `device` | str | `"cuda"` | PyTorch device, falls back to CPU if no GPU | -| `batch_size` | int | `1` | Batch size for processing | -| `batch_size_ch` | int | `-1` | Channel batch size; `-1` processes all at once | -| `fp16` | bool | `False` | Use half-precision for memory efficiency | -| `trainable` | bool | `False` | Make filter bank parameters learnable (gradient flows) | -| `n_perm` | int or None | `None` | Number of permutations for surrogate testing | -| `amp_prob` | bool | `False` | Normalize amplitude to probability distribution | - -### Returns - -- `pac_vals`: PAC values, shape `(batch, chs, pha_n_bands, amp_n_bands)` -- `pha_mids_hz`: center frequencies of phase bands, shape `(pha_n_bands,)` -- `amp_mids_hz`: center frequencies of amplitude bands, shape `(amp_n_bands,)` - -## stx.dsp.modulation_index - -Lower-level function: compute PAC from pre-computed phase and amplitude arrays. - -```python -mi = stx.dsp.modulation_index(pha, amp, n_bins=18, amp_prob=False) -``` - -### Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `pha` | Tensor | required | Phase signal, shape `(batch, chs, n_freqs_pha, n_segments, seq_len)` | -| `amp` | Tensor | required | Amplitude signal, shape `(batch, chs, n_freqs_amp, n_segments, seq_len)` | -| `n_bins` | int | `18` | Number of phase bins for the mean-vector length computation | -| `amp_prob` | bool | `False` | Normalize amplitude distribution | - -### Returns - -- `mi`: modulation index values - -## Helper: \_reshape - -```python -reshaped = stx.dsp._reshape(x, batch_size=2, n_chs=4) -``` - -Utility to reshape a raw PAC tensor `x` into `(batch, chs, ...)` format for `modulation_index`. - -## Examples - -### End-to-end PAC - -```python -import scitex as stx - -FS = 512 -xx, tt, fs = stx.dsp.demo_sig( - batch_size=1, n_chs=1, fs=FS, t_sec=4, sig_type="tensorpac" -) - -pac_vals, pha_mids, amp_mids = stx.dsp.pac( - xx, fs, - pha_start_hz=2, pha_end_hz=20, pha_n_bands=50, - amp_start_hz=60, amp_end_hz=160, amp_n_bands=30, -) -# pac_vals.shape: (1, 1, 50, 30) - -# Plot comodulogram -import matplotlib.pyplot as plt -fig, ax = plt.subplots() -im = ax.imshow(pac_vals[0, 0].T, origin="lower", aspect="auto") -ax.set_xlabel("Phase frequency [Hz]") -ax.set_ylabel("Amplitude frequency [Hz]") -plt.colorbar(im, label="PAC (MI)") -``` - -### Memory-efficient channel batching - -```python -pac_vals, pha_mids, amp_mids = stx.dsp.pac( - xx, fs, - batch_size_ch=4, # process 4 channels at a time - fp16=True, # halve memory usage -) -``` - -### Trainable filter banks (gradient-based optimization) - -```python -# Filters become nn.Parameters; gradients flow through -pac_vals, pha_mids, amp_mids = stx.dsp.pac( - xx, fs, trainable=True -) -pac_vals.sum().backward() # gradients available on pha_mids / amp_mids -``` - -### Compare with Tensorpac - -```python -from scitex.dsp.utils.pac import calc_pac_with_tensorpac, plot_pac_scitex_vs_tensorpac - -# Tensorpac reference calculation -phases, amplitudes, freqs_pha, freqs_amp, pac_tp = calc_pac_with_tensorpac( - xx, fs, t_sec=4, i_batch=0, i_ch=0 -) - -# Plot side-by-side comparison -fig = plot_pac_scitex_vs_tensorpac( - pac_vals[0, 0], pac_tp, freqs_pha, freqs_amp -) -``` diff --git a/src/scitex/dsp/_skills/params.md b/src/scitex/dsp/_skills/params.md deleted file mode 100644 index 0832b2dd3..000000000 --- a/src/scitex/dsp/_skills/params.md +++ /dev/null @@ -1,103 +0,0 @@ ---- -description: Built-in EEG frequency bands and standard electrode montages. ---- - -# stx.dsp.params — Parameters and Constants - -Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/params.py` - -## stx.dsp.params.BANDS - -Standard EEG frequency bands as a pandas DataFrame. - -```python -import scitex as stx - -print(stx.dsp.params.BANDS) -# delta theta lalpha halpha beta gamma -# low_hz 0.5 4.0 8.0 10.0 13.0 32.0 -# high_hz 4.0 8.0 10.0 13.0 32.0 75.0 -``` - -### Access patterns - -```python -bands = stx.dsp.params.BANDS - -# Get a single band's range -delta_low = bands["delta"]["low_hz"] # 0.5 -delta_high = bands["delta"]["high_hz"] # 4.0 - -# Get as array for use with stx.dsp.filt.bandpass -import numpy as np - -# All bands stacked -all_bands = bands.values.T # shape (6, 2): [[low1, high1], ...] - -# Single band -theta_band = np.array([[bands["theta"]["low_hz"], bands["theta"]["high_hz"]]]) -# [[4.0, 8.0]] - -# Multiple specific bands -gamma_beta = np.array([ - [bands["beta"]["low_hz"], bands["beta"]["high_hz"]], - [bands["gamma"]["low_hz"], bands["gamma"]["high_hz"]], -]) -``` - -### Use with filtering - -```python -import scitex as stx -import numpy as np - -xx, tt, fs = stx.dsp.demo_sig(fs=256, t_sec=4) - -bands = stx.dsp.params.BANDS - -# Bandpass filter to theta (4-8 Hz) -theta_band = np.array([[bands["theta"]["low_hz"], bands["theta"]["high_hz"]]]) -xx_theta = stx.dsp.filt.bandpass(xx, fs, theta_band) - -# Filter to all standard bands at once -all_bands_array = bands.values.T # (6, 2) -xx_all_bands = stx.dsp.filt.bandpass(xx, fs, all_bands_array) -# xx_all_bands.shape: (batch, chs, 6, time) -``` - -## stx.dsp.params.EEG_MONTAGE_1020 - -Standard 10-20 EEG electrode names (19 electrodes). - -```python -stx.dsp.params.EEG_MONTAGE_1020 -# ['FP1', 'F3', 'C3', 'P3', 'O1', -# 'FP2', 'F4', 'C4', 'P4', 'O2', -# 'F7', 'T7', 'P7', 'F8', 'T8', 'P8', -# 'FZ', 'CZ', 'PZ'] -``` - -Used as default in `stx.dsp.get_eeg_pos()`. - -## stx.dsp.params.EEG_MONTAGE_BIPOLAR_TRANVERSE - -Bipolar transverse montage (14 channel pairs). - -```python -stx.dsp.params.EEG_MONTAGE_BIPOLAR_TRANVERSE -# ['FP1-FP2', 'F7-F3', 'F3-FZ', 'FZ-F4', 'F4-F8', -# 'T7-C3', 'C3-CZ', 'CZ-C4', 'C4-T8', -# 'P7-P3', 'P3-PZ', 'PZ-P4', 'P4-P8', -# 'O1-O2'] -``` - -## stx.dsp.get_eeg_pos (optional, requires MNE) - -Get 3D electrode positions from the standard 10-20 montage. - -```python -df = stx.dsp.get_eeg_pos(channel_names=stx.dsp.params.EEG_MONTAGE_1020) -# Returns DataFrame with columns = channel names, rows = [x, y, z] -``` - -Raises `ImportError` if `mne` is not installed. diff --git a/src/scitex/dsp/_skills/referencing.md b/src/scitex/dsp/_skills/referencing.md deleted file mode 100644 index e83c01acb..000000000 --- a/src/scitex/dsp/_skills/referencing.md +++ /dev/null @@ -1,92 +0,0 @@ ---- -description: Common-average, random, and target channel re-referencing for EEG/LFP signals. ---- - -# stx.dsp.reference — Re-referencing - -Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/reference.py` - -All functions are decorated with `@torch_fn`, accept NumPy arrays or PyTorch tensors, and return the same type. They require `torch`. - -Re-referencing is applied along the channel dimension (`dim=-2` by default, i.e., the second-to-last axis in a `(batch, chs, time)` tensor). - -## Functions - -```python -re_ref = stx.dsp.reference.common_average(x, dim=-2) -re_ref = stx.dsp.reference.random(x, dim=-2) -re_ref = stx.dsp.reference.take_reference(x, tgt_indi, dim=-2) -``` - -### stx.dsp.reference.common_average - -Subtract the mean across all channels (common average reference), then z-score. - -Formula: `(x - mean(x, dim)) / std(x, dim)` - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `x` | ndarray / Tensor | required | Signal `(batch, chs, time)` | -| `dim` | int | `-2` | Channel dimension | - -### stx.dsp.reference.random - -Subtract a random permutation of the channel data from the original. - -Each call produces a different result (non-deterministic). Useful for data augmentation. - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `x` | ndarray / Tensor | required | Signal `(batch, chs, time)` | -| `dim` | int | `-2` | Channel dimension | - -### stx.dsp.reference.take_reference - -Subtract a specific channel (or set of channels) from all channels. - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `x` | ndarray / Tensor | required | Signal `(batch, chs, time)` | -| `tgt_indi` | int or slice | required | Index/indices of the reference channel(s) | -| `dim` | int | `-2` | Channel dimension | - -## Examples - -```python -import scitex as stx - -xx, tt, fs = stx.dsp.demo_sig(n_chs=19, fs=256, t_sec=2) -# xx.shape: (8, 19, 512) - -# Common average reference (standard for EEG) -xx_car = stx.dsp.reference.common_average(xx) -assert xx_car.shape == xx.shape - -# Random reference (data augmentation) -xx_rand = stx.dsp.reference.random(xx) - -# Reference to channel 0 (linked mastoid, for example) -xx_ref0 = stx.dsp.reference.take_reference(xx, tgt_indi=0) - -# Reference to average of channels 17 and 18 (bilateral mastoids) -xx_bm = stx.dsp.reference.take_reference(xx, tgt_indi=slice(17, 19)) -``` - -## EEG montage pipeline - -```python -import scitex as stx -import numpy as np - -# Load raw EEG (assumed shape: batch, chs, time) -xx, tt, fs = stx.dsp.demo_sig(sig_type="meg", n_chs=19, fs=256) - -# 1. Re-reference to common average -xx = stx.dsp.reference.common_average(xx) - -# 2. Bandpass filter to remove DC and high-frequency noise -xx = stx.dsp.filt.bandpass(xx, fs, np.array([[0.5, 80]])) - -# 3. Z-score normalize -xx = stx.dsp.norm.z(xx) -``` diff --git a/src/scitex/dsp/_skills/resampling.md b/src/scitex/dsp/_skills/resampling.md deleted file mode 100644 index fe0b533d2..000000000 --- a/src/scitex/dsp/_skills/resampling.md +++ /dev/null @@ -1,83 +0,0 @@ ---- -description: Anti-aliased signal resampling using torchaudio. ---- - -# stx.dsp.resample — Resampling - -Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_resample.py` - -## Signature - -```python -xr = stx.dsp.resample(x, src_fs, tgt_fs, t=None) -# or, with time vector: -xr, tr = stx.dsp.resample(x, src_fs, tgt_fs, t=tt) -``` - -Uses `torchaudio.transforms.Resample` for polyphase anti-aliased resampling. Decorated with `@signal_fn`. - -Requires `torch` and `torchaudio`. - -### Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `x` | ndarray / Tensor | required | Signal, shape `(batch, chs, time)` | -| `src_fs` | float | required | Source sampling frequency in Hz | -| `tgt_fs` | float | required | Target sampling frequency in Hz | -| `t` | ndarray or None | `None` | Optional time vector; if given, a resampled time vector is also returned | - -### Returns - -- `xr`: resampled signal, shape `(batch, chs, new_time)` where `new_time = round(time * tgt_fs / src_fs)` -- If `t` is provided: `(xr, tr)` where `tr` is a new time vector spanning the same range as `t` - -## Examples - -```python -import scitex as stx - -T_SEC = 1 -SRC_FS = 128 -xx, tt, fs = stx.dsp.demo_sig(sig_type="chirp", t_sec=T_SEC, fs=SRC_FS) - -# Downsample to 64 Hz -xd, td = stx.dsp.resample(xx, fs, 64, t=tt) -print(f"Original: {xx.shape}, {tt.shape}") -print(f"Downsampled: {xd.shape}, {td.shape}") - -# Upsample to 256 Hz -xu, tu = stx.dsp.resample(xx, fs, 256, t=tt) -print(f"Upsampled: {xu.shape}, {tu.shape}") - -# Without time vector -xd = stx.dsp.resample(xx, fs, 64) -``` - -## Common use cases - -```python -import scitex as stx -import numpy as np - -xx, tt, fs = stx.dsp.demo_sig(fs=1000, t_sec=10) - -# Preprocessing pipeline: downsample before filtering for speed -xx_256, tt_256 = stx.dsp.resample(xx, fs, 256, t=tt) -xx_filt = stx.dsp.filt.bandpass(xx_256, 256, np.array([[4, 80]])) - -# Resample before ripple detection (detect_ripples does this internally) -# but you can pre-downsample manually: -xx_low = stx.dsp.resample(xx, fs, 300) # ripple band is 80-140 Hz; 3x = 420 Hz - -# Match sampling rates between two recordings -xx_a, tt_a, fs_a = stx.dsp.demo_sig(fs=1024) -xx_b, tt_b, fs_b = stx.dsp.demo_sig(fs=512) -xx_a_resampled = stx.dsp.resample(xx_a, fs_a, fs_b) -``` - -## Notes - -- The resampled time vector `tr` is computed with `torch.linspace(t[0], t[-1], new_len)`, preserving the original time span. -- Resampling uses the dtype of the input tensor; convert to float32 first if needed. -- `detect_ripples` internally downsamples to `low_hz * 3` Hz automatically. diff --git a/src/scitex/dsp/_skills/ripple-detection.md b/src/scitex/dsp/_skills/ripple-detection.md deleted file mode 100644 index 9b9b7ff2f..000000000 --- a/src/scitex/dsp/_skills/ripple-detection.md +++ /dev/null @@ -1,122 +0,0 @@ ---- -description: Detect hippocampal sharp-wave ripples from wide-band LFP or EEG signals. ---- - -# stx.dsp.detect_ripples — Ripple Detection - -Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_detect_ripples.py` - -## Signature - -```python -df = stx.dsp.detect_ripples( - xx, - fs, - low_hz, - high_hz, - sd=2.0, - smoothing_sigma_ms=4, - min_duration_ms=10, - return_preprocessed_signal=False, -) -``` - -### Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `xx` | ndarray | required | Signal, shape `(n_chs, time)` or `(batch, n_chs, time)` | -| `fs` | float | required | Sampling frequency in Hz | -| `low_hz` | float | required | Lower edge of ripple band (e.g. `80`) | -| `high_hz` | float | required | Upper edge of ripple band (e.g. `140`) | -| `sd` | float | `2.0` | Threshold in standard deviations above mean for peak detection | -| `smoothing_sigma_ms` | float | `4` | Gaussian smoothing width in milliseconds | -| `min_duration_ms` | float | `10` | Minimum ripple duration in milliseconds | -| `return_preprocessed_signal` | bool | `False` | Also return the preprocessed RMS envelope | - -### Returns - -- If `return_preprocessed_signal=False` (default): pandas `DataFrame` -- If `return_preprocessed_signal=True`: `(df, xx_r, fs_r)` where `xx_r` is the preprocessed signal and `fs_r` is the downsampled fs - -## Output DataFrame Columns - -| Column | Type | Description | -|--------|------|-------------| -| `start_s` | float | Event start time in seconds | -| `end_s` | float | Event end time in seconds | -| `duration_s` | float | Event duration in seconds | -| `peak_s` | float | Time of peak amplitude in seconds | -| `rel_peak_pos` | float | Peak position within event, 0.0–1.0 | -| `peak_amp_sd` | float | Peak amplitude in standard deviations | - -The DataFrame index holds the channel index for each detected event. - -## Preprocessing pipeline (internal) - -1. **Downsample** to `low_hz * 3` Hz to speed up computation -2. **Common-average subtraction** across channels to reduce EMG artifacts -3. **Bandpass filter** to `[low_hz, high_hz]` -4. **RMS envelope**: square, Hilbert amplitude, Gaussian smooth, square-root -5. **Average across channels**, then **z-score** normalization -6. **Peak detection** at threshold `sd` standard deviations -7. **Edge removal**: drop events within `3 / low_hz` seconds of recording edges - -## Examples - -```python -import scitex as stx - -# Generate synthetic ripple signal (requires ripple_detection package) -xx, tt, fs = stx.dsp.demo_sig(sig_type="ripple", fs=1000, t_sec=10) - -# Detect ripples in 80-140 Hz band -df = stx.dsp.detect_ripples(xx, fs, low_hz=80, high_hz=140) -print(df) -# start_s end_s duration_s peak_s rel_peak_pos peak_amp_sd - -# Custom threshold: require 3 SD and 20 ms minimum duration -df = stx.dsp.detect_ripples( - xx, fs, - low_hz=80, high_hz=140, - sd=3.0, - min_duration_ms=20, -) - -# Get preprocessed envelope for inspection -df, xx_r, fs_r = stx.dsp.detect_ripples( - xx, fs, - low_hz=80, high_hz=140, - return_preprocessed_signal=True, -) -print(f"Preprocessed fs: {fs_r} Hz, shape: {xx_r.shape}") -``` - -## Low-level helper functions - -These are exported from `stx.dsp` for advanced use: - -```python -# Preprocessing (bandpass, RMS, z-score) -xx_r, fs_r = stx.dsp._preprocess(xx, fs, low_hz=80, high_hz=140) - -# Event detection from preprocessed signal -df = stx.dsp._find_events(xx_r, fs_r, sd=2.0, min_duration_ms=10) - -# Drop detections near recording edges -df = stx.dsp._drop_ripples_at_edges(df, low_hz=80, xx_r=xx_r, fs_r=fs_r) - -# Add relative peak position column -df = stx.dsp._calc_relative_peak_position(df) - -# Reorder columns into canonical order -df = stx.dsp._sort_columns(df) -``` - -## Typical ripple band parameters - -| Region | Low Hz | High Hz | -|--------|--------|---------| -| Hippocampus SWR | 80 | 140 | -| Fast ripples | 200 | 400 | -| Sleep spindles | 12 | 15 | diff --git a/src/scitex/dsp/_skills/segmentation.md b/src/scitex/dsp/_skills/segmentation.md deleted file mode 100644 index 4d3b6a2fd..000000000 --- a/src/scitex/dsp/_skills/segmentation.md +++ /dev/null @@ -1,134 +0,0 @@ ---- -description: Sliding-window segmentation, signal cropping, and sktime DataFrame conversion. ---- - -# stx.dsp — Segmentation - -Sources: -- `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_transform.py` -- `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_crop.py` - -## stx.dsp.to_segments - -PyTorch-based sliding-window segmentation using `unfold`. - -```python -windows = stx.dsp.to_segments(x, window_size, overlap_factor=1, dim=-1) -``` - -Decorated with `@torch_fn`. - -### Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `x` | Tensor / ndarray | required | Input signal | -| `window_size` | int | required | Number of samples per window | -| `overlap_factor` | int | `1` | Stride = `window_size // overlap_factor`; `1` = no overlap, `2` = 50% overlap | -| `dim` | int | `-1` | Time dimension to segment | - -### Returns - -Tensor with a new trailing dimension: `(..., n_windows, window_size)`. - -### Example - -```python -import scitex as stx - -xx, tt, fs = stx.dsp.demo_sig() -# xx.shape: (8, 19, 2048) - -# Non-overlapping 256-sample windows -segments = stx.dsp.to_segments(xx, window_size=256) -# segments.shape: (8, 19, n_windows, 256) - -# 50% overlapping windows -segments_50 = stx.dsp.to_segments(xx, window_size=256, overlap_factor=2) -``` - -## stx.dsp.crop - -NumPy-based signal cropping into windows. More flexible than `to_segments`: works on any axis and returns an optional time array. - -```python -cropped = stx.dsp.crop(sig_2d, window_length, overlap_factor=0.0, axis=-1, time=None) -# or -cropped, cropped_times = stx.dsp.crop(sig_2d, window_length, overlap_factor=0.5, time=tt) -``` - -### Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `sig_2d` | ndarray | required | Signal array, any dimensionality | -| `window_length` | int | required | Window length in samples | -| `overlap_factor` | float | `0.0` | Fraction overlap between windows (0.0 = no overlap, 0.5 = 50%) | -| `axis` | int | `-1` | Axis to crop along | -| `time` | ndarray or None | `None` | Time vector matching length along `axis` | - -### Returns - -- `cropped_windows`: shape `(n_windows, *original_shape_with_window_axis)` -- If `time` is given: `(cropped_windows, cropped_times)` where `cropped_times.shape = (n_windows, window_length)` - -### Example - -```python -import scitex as stx -import numpy as np - -FS = 128 -sig2d = np.random.rand(19, FS * 13) # 19 channels, 13 seconds -time = np.arange(sig2d.shape[-1]) / FS - -window_pts = FS * 2 # 2-second windows - -# 50% overlapping crop -xx, tt = stx.dsp.crop(sig2d, window_pts, overlap_factor=0.5, time=time) -# xx.shape: (n_windows, 19, 256) -# tt.shape: (n_windows, 256) - -# No time output -xx = stx.dsp.crop(sig2d, window_pts, overlap_factor=0.5) -``` - -## stx.dsp.to_sktime_df - -Convert a 3D array to sktime-compatible DataFrame format (nested DataFrames). - -```python -df = stx.dsp.to_sktime_df(arr) -``` - -### Parameters - -| Parameter | Type | Description | -|-----------|------|-------------| -| `arr` | ndarray | Shape `(n_samples, seq_len, n_channels)` — note axis order | - -### Returns - -pandas `DataFrame` with one column (`dim_0`), where each cell contains a `pd.Series` of all channels for that sample. - -### Example - -```python -import scitex as stx -import numpy as np - -arr = np.random.randn(100, 256, 4) # 100 samples, 256 timepoints, 4 channels -sktime_df = stx.dsp.to_sktime_df(arr) -# sktime_df.shape: (100, 1) -# sktime_df.iloc[0, 0]: Series with channel_0, channel_1, channel_2, channel_3 -``` - -## Comparison: `crop` vs `to_segments` - -| | `crop` | `to_segments` | -|-|--------|---------------| -| Backend | NumPy | PyTorch | -| Time vector | Supported | Not supported | -| Overlap spec | Float fraction (0.0–1.0) | Integer factor | -| Input dims | Any | `dim` parameter | -| Decorator | None | `@torch_fn` | diff --git a/src/scitex/dsp/_skills/spectral.md b/src/scitex/dsp/_skills/spectral.md deleted file mode 100644 index 4e2a6f6cf..000000000 --- a/src/scitex/dsp/_skills/spectral.md +++ /dev/null @@ -1,94 +0,0 @@ ---- -description: Power spectral density and per-band average power for multi-channel signals. ---- - -# stx.dsp — Spectral Analysis - -Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_psd.py` - -## psd - -```python -psd_vals, freqs = stx.dsp.psd(x, fs, prob=False, dim=-1) -``` - -Computes the power spectral density using the PyTorch `PSD` module from `scitex.nn._PSD`. - -### Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `x` | ndarray / Tensor | required | Input signal, shape `(batch, chs, time)` | -| `fs` | float | required | Sampling frequency in Hz | -| `prob` | bool | `False` | If `True`, normalize PSD to sum to 1 (probability distribution) | -| `dim` | int | `-1` | Time dimension | - -### Returns - -- `psd_vals`: power spectrum, shape `(batch, chs, n_freqs)`, log-scaled -- `freqs`: frequency axis array, shape `(n_freqs,)`, in Hz - -Requires `torch`. - -## band_powers - -```python -avg_powers = stx.dsp.band_powers(self, psd) -``` - -Computes average power within specified frequency bands from an existing PSD. - -Note: `band_powers` as exposed in `__init__.py` is a lower-level function that requires `self` (a PSD instance) and a pre-computed `psd` tensor. It is typically used internally or via the `PSD` class directly. - -## Built-in Frequency Bands - -Predefined bands are available in `stx.dsp.params.BANDS`: - -```python -import scitex as stx - -print(stx.dsp.params.BANDS) -# delta theta lalpha halpha beta gamma -# low_hz 0.5 4.0 8.0 10.0 13.0 32.0 -# high_hz 4.0 8.0 10.0 13.0 32.0 75.0 -``` - -## Examples - -```python -import scitex as stx - -xx, tt, fs = stx.dsp.demo_sig(sig_type="chirp", fs=512, t_sec=4) - -# Compute PSD -psd_vals, freqs = stx.dsp.psd(xx, fs) -# psd_vals.shape: (8, 19, n_freqs) -# freqs.shape: (n_freqs,) - -# Normalized PSD (sums to 1 per channel) -psd_prob, freqs = stx.dsp.psd(xx, fs, prob=True) - -# Plot PSD for first batch/channel -import matplotlib.pyplot as plt -fig, ax = plt.subplots() -ax.plot(freqs, psd_vals[0, 0]) -ax.set_xlabel("Frequency [Hz]") -ax.set_ylabel("log(Power [uV^2 / Hz])") -``` - -## Full pipeline example - -```python -import scitex as stx -import numpy as np - -xx, tt, fs = stx.dsp.demo_sig(fs=512, t_sec=2) - -# Filter to gamma band first, then compute PSD -gamma = stx.dsp.filt.bandpass(xx, fs, np.array([[32, 75]])) -psd_gamma, freqs = stx.dsp.psd(gamma, fs) - -# Check which frequency has peak power in first batch/channel -peak_idx = psd_gamma[0, 0].argmax() -print(f"Peak frequency: {freqs[peak_idx]:.1f} Hz") -``` diff --git a/src/scitex/dsp/_skills/utils.md b/src/scitex/dsp/_skills/utils.md deleted file mode 100644 index 184329525..000000000 --- a/src/scitex/dsp/_skills/utils.md +++ /dev/null @@ -1,175 +0,0 @@ ---- -description: Internal helpers — zero-padding, FIR filter design, differentiable bandpass filter banks. ---- - -# stx.dsp.utils — Utilities - -Source directory: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/utils/` - -These utilities are used internally by the higher-level DSP functions but can also be called directly. - -## Zero-padding - -Source: `utils/_zero_pad.py` - -```python -from scitex.dsp.utils import zero_pad, _zero_pad_1d -``` - -### `_zero_pad_1d(x, target_length)` - -Zero-pad a 1D tensor to a target length, padding symmetrically. - -```python -from scitex.dsp.utils import _zero_pad_1d -import torch - -x = torch.tensor([1.0, 2.0, 3.0]) -padded = _zero_pad_1d(x, target_length=7) -# tensor([0., 1., 2., 3., 0., 0., 0.]) — 2 left, 2 right -``` - -### `zero_pad(xs, dim=0)` - -Zero-pad a list of variable-length tensors/arrays to the same length and stack them. - -```python -from scitex.dsp.utils import zero_pad -import torch - -xs = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])] -stacked = zero_pad(xs, dim=0) -# tensor([[1., 2., 0.], -# [3., 4., 5.]]) -``` - -Accepts NumPy arrays, converts them to tensors automatically. - -## FIR filter design - -Source: `utils/filter.py` - -```python -from scitex.dsp.utils.filter import design_filter, plot_filter_responses -``` - -### `design_filter(sig_len, fs, low_hz=None, high_hz=None, cycle=3, is_bandstop=False)` - -Design an FIR filter using `scipy.signal.firwin` with Hamming window. Returns filter coefficient array. - -Decorated with `@numpy_fn` (converts tensors to arrays automatically). - -| Parameter | Description | -|-----------|-------------| -| `sig_len` | Signal length (determines maximum filter order) | -| `fs` | Sampling frequency | -| `low_hz` | Low cutoff (omit for highpass) | -| `high_hz` | High cutoff (omit for lowpass) | -| `cycle` | Number of cycles at lowest frequency; determines filter order | -| `is_bandstop` | `True` for bandstop when both `low_hz` and `high_hz` given | - -Filter type selection: -- `low_hz` only → lowpass -- `high_hz` only → highpass -- both → bandpass (default) or bandstop (`is_bandstop=True`) - -```python -from scitex.dsp.utils.filter import design_filter, plot_filter_responses -import scitex as stx - -xx, tt, fs = stx.dsp.demo_sig() -seq_len = xx.shape[-1] - -bp = design_filter(seq_len, fs, low_hz=30, high_hz=70) -lp = design_filter(seq_len, fs, low_hz=30) -hp = design_filter(seq_len, fs, high_hz=70) -bs = design_filter(seq_len, fs, low_hz=30, high_hz=70, is_bandstop=True) -``` - -### `plot_filter_responses(filter, fs, worN=8000, title=None)` - -Plot impulse response and frequency response of an FIR filter. Returns a matplotlib `Figure`. - -```python -fig = plot_filter_responses(bp, fs, title="Bandpass 30-70 Hz") -``` - -## Differentiable bandpass filter banks (for gradient-based optimization) - -Source: `utils/_differential_bandpass_filters.py` - -```python -from scitex.dsp.utils import init_bandpass_filters, build_bandpass_filters -``` - -These build learnable PAC filter banks whose center frequencies can be optimized via backpropagation. - -Requires `torchaudio.prototype.functional.sinc_impulse_response`. - -### `init_bandpass_filters(sig_len, fs, pha_low_hz, pha_high_hz, pha_n_bands, amp_low_hz, amp_high_hz, amp_n_bands, cycle)` - -Initialize a filter bank with learnable `pha_mids` and `amp_mids` parameters. - -```python -from scitex.dsp.utils import init_bandpass_filters -import scitex as stx - -xx, tt, fs = stx.dsp.demo_sig(fs=1024) -filters, pha_mids, amp_mids = init_bandpass_filters( - sig_len=xx.shape[-1], - fs=fs, - pha_low_hz=2, pha_high_hz=20, pha_n_bands=30, - amp_low_hz=60, amp_high_hz=160, amp_n_bands=50, -) -# filters: stacked impulse responses shape (pha_n_bands + amp_n_bands, filter_len) -# pha_mids, amp_mids: nn.Parameter — gradients flow through these - -# Verify gradients work -filters.sum().backward() -print(pha_mids.grad) # not None -``` - -### `build_bandpass_filters(sig_len, fs, pha_mids, amp_mids, cycle)` - -Rebuild filter bank from (updated) center frequency parameters. Call this in the forward pass after optimizer.step() to apply learned frequencies. - -```python -from scitex.dsp.utils import build_bandpass_filters - -# After optimizer step: -new_filters = build_bandpass_filters(sig_len, fs, pha_mids, amp_mids, cycle=3) -``` - -## ensure_3d - -Source: `utils/_ensure_3d.py` (also available as `stx.dsp.ensure_3d`) - -```python -x_3d = stx.dsp.ensure_3d(x) -``` - -Promotes 1D `(time,)` or 2D `(batch, time)` tensors to 3D `(batch, chs, time)` for compatibility with all DSP functions. - -```python -import torch -import scitex as stx - -x1d = torch.randn(512) -x2d = torch.randn(8, 512) -x3d = torch.randn(8, 19, 512) - -stx.dsp.ensure_3d(x1d).shape # (1, 1, 512) -stx.dsp.ensure_3d(x2d).shape # (8, 1, 512) -stx.dsp.ensure_3d(x3d).shape # (8, 19, 512) — unchanged -``` - -## stx.dsp.time - -Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_time.py` - -Generate a time vector using `stx.gen.float_linspace`. - -```python -t = stx.dsp.time(start_sec=0, end_sec=5, fs=256) -# Returns array of (end_sec - start_sec) * fs evenly spaced values -``` diff --git a/src/scitex/dsp/_skills/wavelet.md b/src/scitex/dsp/_skills/wavelet.md deleted file mode 100644 index 45fb4e51a..000000000 --- a/src/scitex/dsp/_skills/wavelet.md +++ /dev/null @@ -1,119 +0,0 @@ ---- -description: Continuous wavelet transform returning time-frequency phase and amplitude. ---- - -# stx.dsp.wavelet — Wavelet Transform - -Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_wavelet.py` - -## Signature - -```python -pha, amp, freqs = stx.dsp.wavelet( - x, - fs, - freq_scale="linear", - out_scale="linear", - device="cuda", - batch_size=32, -) -``` - -Computes a continuous wavelet transform (CWT) using the `Wavelet` module from `scitex.nn._Wavelet`. The function is decorated with both `@signal_fn` and `@batch_fn`, so it handles type conversion and automatic batch splitting when input is too large for GPU memory. - -Requires `torch`. - -### Parameters - -| Parameter | Type | Default | Description | -|-----------|------|---------|-------------| -| `x` | ndarray / Tensor | required | Signal, shape `(batch, chs, time)` | -| `fs` | float | required | Sampling frequency in Hz | -| `freq_scale` | str | `"linear"` | Frequency axis spacing: `"linear"` or `"log"` | -| `out_scale` | str | `"linear"` | Output amplitude scale: `"linear"` or `"log"` | -| `device` | str | `"cuda"` | PyTorch device | -| `batch_size` | int | `32` | Batch size for the `@batch_fn` wrapper | - -### Returns - -- `pha`: instantaneous phase, shape `(batch, chs, n_freqs, time)` -- `amp`: amplitude envelope, shape `(batch, chs, n_freqs, time)` -- `freqs`: frequency axis, shape `(batch, chs, n_freqs)` — take `freqs[0, 0]` for the 1D array - -## Examples - -```python -import scitex as stx -import numpy as np - -xx, tt, fs = stx.dsp.demo_sig( - sig_type="chirp", batch_size=4, n_chs=2, fs=512, t_sec=4 -) - -# Compute wavelet transform -pha, amp, freqs = stx.dsp.wavelet(xx, fs, device="cuda") - -# freqs is per-batch/channel; take the 1D version -freqs_1d = freqs[0, 0] # shape: (n_freqs,) - -print(f"pha shape: {pha.shape}") # (4, 2, n_freqs, 2048) -print(f"amp shape: {amp.shape}") # (4, 2, n_freqs, 2048) -print(f"freqs: {freqs_1d}") -``` - -### Log-scale amplitude output - -```python -pha, amp_log, freqs = stx.dsp.wavelet(xx, fs, out_scale="log") -# amp_log contains log(amplitude + 1e-5) — NaN-safe log scaling -``` - -### Spectrogram plot - -```python -import matplotlib.pyplot as plt - -pha, amp, freqs = stx.dsp.wavelet(xx, fs) -freqs_1d = freqs[0, 0].cpu().numpy() -i_batch, i_ch = 0, 0 - -fig, axes = plt.subplots(3, 1, figsize=(10, 8)) - -# Raw signal -axes[0].plot(tt, xx[i_batch, i_ch]) -axes[0].set_ylabel("Amplitude") -axes[0].set_title("Signal") - -# Amplitude spectrogram -log_amp = (amp[i_batch, i_ch] + 1e-5).log().cpu().numpy() -axes[1].imshow(log_amp.T, aspect="auto", origin="lower") -axes[1].set_ylabel("Frequency [Hz]") -axes[1].set_title("Wavelet Amplitude") - -# Phase spectrogram -phase_np = pha[i_batch, i_ch].cpu().numpy() -axes[2].imshow(phase_np.T, aspect="auto", origin="lower") -axes[2].set_ylabel("Frequency [Hz]") -axes[2].set_title("Wavelet Phase [rad]") -axes[2].set_xlabel("Time [s]") -``` - -### Log frequency scale - -```python -pha, amp, freqs = stx.dsp.wavelet(xx, fs, freq_scale="log") -# freqs are logarithmically spaced — more resolution at low frequencies -``` - -## PAC segments - -When input has a segment dimension `(batch, chs, n_segments, time)`, extract one segment first: - -```python -xx, tt, fs = stx.dsp.demo_sig(sig_type="pac", n_segments=20, fs=512, t_sec=4) -# xx.shape: (8, 19, 20, 2048) — has segment dim - -i_segment = 0 -xx_seg = xx[:, :, i_segment, :] # (8, 19, 2048) -pha, amp, freqs = stx.dsp.wavelet(xx_seg, fs) -``` diff --git a/src/scitex/dsp/_time.py b/src/scitex/dsp/_time.py deleted file mode 100755 index 659d143bf..000000000 --- a/src/scitex/dsp/_time.py +++ /dev/null @@ -1,40 +0,0 @@ -#!./env/bin/python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-06-30 12:11:01 (ywatanabe)" -# /mnt/ssd/ripple-wm-code/scripts/externals/scitex/src/scitex/dsp/_time.py - - -import numpy as np - -import scitex - - -def time(start_sec, end_sec, fs): - # return np.linspace(start_sec, end_sec, (end_sec - start_sec) * fs) - return scitex.gen.float_linspace(start_sec, end_sec, (end_sec - start_sec) * fs) - - -def main(): - out = time(10, 15, 256) - print(out) - - -if __name__ == "__main__": - import sys - - import matplotlib.pyplot as plt - - # # Argument Parser - # import argparse - # parser = argparse.ArgumentParser(description='') - # parser.add_argument('--var', '-v', type=int, default=1, help='') - # parser.add_argument('--flag', '-f', action='store_true', default=False, help='') - # args = parser.parse_args() - # Main - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start( - sys, plt, verbose=False - ) - main() - scitex.session.close(CONFIG, verbose=False, notify=False) - -# EOF diff --git a/src/scitex/dsp/_transform.py b/src/scitex/dsp/_transform.py deleted file mode 100755 index caeda6aca..000000000 --- a/src/scitex/dsp/_transform.py +++ /dev/null @@ -1,80 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "2024-04-08 12:41:59 (ywatanabe)"#!/usr/bin/env python3 - - -import numpy as np -import pandas as pd - -try: - import torch - - TORCH_AVAILABLE = True -except ImportError: - TORCH_AVAILABLE = False - torch = None - -from scitex.decorators import torch_fn - - -def _check_torch(): - if not TORCH_AVAILABLE: - raise ImportError( - "PyTorch is not installed. Please install with: pip install torch" - ) - - -def to_sktime_df(arr): - """ - Convert a 3D numpy array into a DataFrame suitable for sktime. - - Parameters: - arr (numpy.ndarray): A 3D numpy array with shape (n_samples, n_channels, seq_len) - - Returns: - pandas.DataFrame: A DataFrame in sktime format - """ - if len(arr.shape) != 3: - raise ValueError("Input data must be a 3D array") - - n_samples, seq_len, n_channels = arr.shape - - # Initialize an empty DataFrame for sktime format - sktime_df = pd.DataFrame(index=range(n_samples), columns=["dim_0"]) - - # Iterate over each sample - for i in range(n_samples): - # Combine all channels into a single cell - combined_series = pd.Series( - {f"channel_{j}": pd.Series(arr[i, :, j]) for j in range(n_channels)} - ) - sktime_df.iloc[i, 0] = combined_series - - return sktime_df - - -@torch_fn -def to_segments(x, window_size, overlap_factor=1, dim=-1): - stride = window_size // overlap_factor - num_windows = (x.size(dim) - window_size) // stride + 1 - windows = x.unfold(dim, window_size, stride) - return windows - - -if __name__ == "__main__": - import scitex - - x, t, f = scitex.dsp.demo_sig() - - y = to_segments(x, 256) - - x = 100 * np.random.rand(16, 160, 1000) - print(_normalize_time(x)) - - x = torch.randn(16, 160, 1000) - print(_normalize_time(x)) - - x = torch.randn(16, 160, 1000).cuda() - print(_normalize_time(x)) - - -# EOF diff --git a/src/scitex/dsp/_wavelet.py b/src/scitex/dsp/_wavelet.py deleted file mode 100755 index 521111fe4..000000000 --- a/src/scitex/dsp/_wavelet.py +++ /dev/null @@ -1,212 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-11-04 02:12:00 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/_wavelet.py - -"""scitex.dsp.wavelet function""" - -import scitex -from scitex.decorators import batch_fn, signal_fn -from scitex.nn._Wavelet import Wavelet - - -# Functions -@signal_fn -@batch_fn -def wavelet( - x, - fs, - freq_scale="linear", - out_scale="linear", - device="cuda", - batch_size=32, -): - m = Wavelet(fs, freq_scale=freq_scale, out_scale="linear").to(device).eval() - pha, amp, freqs = m(x.to(device)) - - if out_scale == "log": - amp = (amp + 1e-5).log() - if amp.isnan().any(): - print("NaN is detected while taking the lograrithm of amplitude.") - - return pha, amp, freqs - - -# @signal_fn -# def wavelet( -# x, -# fs, -# freq_scale="linear", -# out_scale="linear", -# device="cuda", -# batch_size=32, -# ): -# @signal_fn -# def _wavelet( -# x, -# fs, -# freq_scale="linear", -# out_scale="linear", -# device="cuda", -# ): -# m = ( -# Wavelet(fs, freq_scale=freq_scale, out_scale=out_scale) -# .to(device) -# .eval() -# ) -# pha, amp, freqs = m(x.to(device)) - -# if out_scale == "log": -# amp = (amp + 1e-5).log() -# if amp.isnan().any(): -# print( -# "NaN is detected while taking the lograrithm of amplitude." -# ) - -# return pha, amp, freqs - -# if len(x) <= batch_size: -# try: -# pha, amp, freqs = _wavelet( -# x, -# fs, -# freq_scale=freq_scale, -# out_scale=out_scale, -# device=device, -# ) -# torch.cuda.empty_cache() -# return pha, amp, freqs - -# except Exception as e: -# print(e) -# print("\nTrying Batch Mode...") - -# n_batches = (len(x) + batch_size - 1) // batch_size -# device_orig = x.device -# pha, amp, freqs = [], [], [] -# for i_batch in tqdm(range(n_batches)): -# start = i_batch * batch_size -# end = (i_batch + 1) * batch_size -# _pha, _amp, _freqs = _wavelet( -# x[start:end], -# fs, -# freq_scale=freq_scale, -# out_scale=out_scale, -# device=device, -# ) -# torch.cuda.empty_cache() -# # to CPU -# pha.append(_pha.cpu()) -# amp.append(_amp.cpu()) -# freqs.append(_freqs.cpu()) - -# pha = torch.vstack(pha) -# amp = torch.vstack(amp) -# freqs = freqs[0] - -# try: -# pha = pha.to(device_orig) -# amp = amp.to(device_orig) -# freqs = freqs.to(device_orig) -# except Exception as e: -# print( -# f"\nError occurred while transferring wavelet outputs back to the original device. Proceeding with CPU tensor. \n\n({e})" -# ) - -# sleep(0.5) -# torch.cuda.empty_cache() -# return pha, amp, freqs - - -if __name__ == "__main__": - import sys - - import matplotlib.pyplot as plt - import numpy as np - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt, agg=True) - - # Parameters - FS = 512 - SIG_TYPE = "chirp" - T_SEC = 4 - - # Demo signal - xx, tt, fs = scitex.dsp.demo_sig( - batch_size=64, - n_chs=19, - n_segments=2, - t_sec=T_SEC, - fs=FS, - sig_type=SIG_TYPE, - ) - - if SIG_TYPE in ["tensorpac", "pac"]: - i_segment = 0 - xx = xx[:, :, i_segment, :] - - # Main - pha, amp, freqs = wavelet(xx, fs, device="cuda") - freqs = freqs[0, 0] - - # Plots - i_batch, i_ch = 0, 0 - fig, axes = scitex.plt.subplots(nrows=3) - - # # Time vector for x-axis extents - # time_extent = [tt.min(), tt.max()] - - # Trace - axes[0].plot(tt, xx[i_batch, i_ch], label=SIG_TYPE) - axes[0].set_ylabel("Amplitude [?V]") - axes[0].legend(loc="upper left") - axes[0].set_title("Signal") - - # Amplitude - # extent = [time_extent[0], time_extent[1], freqs.min(), freqs.max()] - axes[1].imshow2d( - np.log(amp[i_batch, i_ch] + 1e-5).T, - cbar_label="Log(amplitude [?V]) [a.u.]", - aspect="auto", - # extent=extent, - # origin="lower", - ) - axes[1] = scitex.plt.ax.set_ticks(axes[1], x_ticks=tt, y_ticks=freqs) - axes[1].set_ylabel("Frequency [Hz]") - axes[1].set_title("Amplitude") - - # Phase - axes[2].imshow2d( - pha[i_batch, i_ch].T, - cbar_label="Phase [rad]", - aspect="auto", - # extent=extent, - # origin="lower", - ) - axes[2] = scitex.plt.ax.set_ticks(axes[2], x_ticks=tt, y_ticks=freqs) - axes[2].set_ylabel("Frequency [Hz]") - axes[2].set_title("Phase") - - fig.suptitle("Wavelet Transformation") - fig.supxlabel("Time [s]") - - for ax in axes: - ax = scitex.plt.ax.set_n_ticks(ax) - # ax.set_xlim(time_extent[0], time_extent[1]) - - fig.tight_layout(rect=[0, 0.03, 1, 0.95]) - - scitex.io.save(fig, "wavelet.png") - - # Close - scitex.session.close(CONFIG) - -# EOF - -""" -/home/ywatanabe/proj/entrance/scitex/dsp/_wavelet.py -""" - - -# EOF diff --git a/src/scitex/dsp/add_noise.py b/src/scitex/dsp/add_noise.py deleted file mode 100755 index 5de59e9e4..000000000 --- a/src/scitex/dsp/add_noise.py +++ /dev/null @@ -1,128 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "ywatanabe (2024-11-02 23:09:49)" -# File: ./scitex_repo/src/scitex/dsp/add_noise.py - -try: - import torch - - TORCH_AVAILABLE = True -except ImportError: - TORCH_AVAILABLE = False - torch = None - -from scitex.decorators import signal_fn - - -def _check_torch(): - if not TORCH_AVAILABLE: - raise ImportError( - "PyTorch is not installed. Please install with: pip install torch" - ) - - -def _uniform(shape, amp=1.0): - _check_torch() - a, b = -amp, amp - return -amp + (2 * amp) * torch.rand(shape) - - -@signal_fn -def gauss(x, amp=1.0): - noise = amp * torch.randn(x.shape) - return x + noise.to(x.device) - - -@signal_fn -def white(x, amp=1.0): - return x + _uniform(x.shape, amp=amp).to(x.device) - - -@signal_fn -def pink(x, amp=1.0, dim=-1): - """ - Adds pink noise to a given tensor along a specified dimension. - - Parameters: - - x (torch.Tensor): The input tensor to which pink noise will be added. - - amp (float, optional): The amplitude of the pink noise. Defaults to 1.0. - - dim (int, optional): The dimension along which to add pink noise. Defaults to -1. - - Returns: - - torch.Tensor: The input tensor with added pink noise. - """ - cols = x.size(dim) - noise = torch.randn(cols, dtype=x.dtype, device=x.device) - noise = torch.fft.rfft(noise) - indices = torch.arange(1, noise.size(0), dtype=x.dtype, device=x.device) - noise[1:] /= torch.sqrt(indices) - noise = torch.fft.irfft(noise, n=cols) - noise = noise - noise.mean() - noise_amp = torch.sqrt(torch.mean(noise**2)) - noise = noise * (amp / noise_amp) - return x + noise.to(x.device) - - -@signal_fn -def brown(x, amp=1.0, dim=-1): - from scitex.dsp import norm - - noise = _uniform(x.shape, amp=amp) - noise = torch.cumsum(noise, dim=dim) - noise = norm.minmax(noise, amp=amp, dim=dim) - return x + noise.to(x.device) - - -if __name__ == "__main__": - import sys - - import matplotlib.pyplot as plt - - import scitex - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - - # Parameters - T_SEC = 1 - FS = 128 - - # Demo signal - xx, tt, fs = scitex.dsp.demo_sig(t_sec=T_SEC, fs=FS) - - funcs = { - "orig": lambda x: x, - "gauss": gauss, - "white": white, - "pink": pink, - "brown": brown, - } - - # Plots - fig, axes = scitex.plt.subplots(nrows=len(funcs), ncols=2, sharex=True, sharey=True) - count = 0 - for (k, fn), axes_row in zip(funcs.items(), axes): - for ax in axes_row: - if count % 2 == 0: - ax.plot(tt, fn(xx)[0, 0], label=k, c="blue") - else: - ax.plot(tt, (fn(xx) - xx)[0, 0], label=f"{k} - orig", c="red") - count += 1 - ax.legend(loc="upper right") - - fig.supxlabel("Time [s]") - fig.supylabel("Amplitude [?V]") - axes[0, 0].set_title("Signal + Noise") - axes[0, 1].set_title("Noise") - - scitex.io.save(fig, "traces.png") - - # Close - scitex.session.close(CONFIG) - -# EOF - -""" -/home/ywatanabe/proj/entrance/scitex/dsp/add_noise.py -""" - -# EOF diff --git a/src/scitex/dsp/audio.md b/src/scitex/dsp/audio.md deleted file mode 100755 index 9ea6f8214..000000000 --- a/src/scitex/dsp/audio.md +++ /dev/null @@ -1,22 +0,0 @@ - - - -## Audio Support -``` bash -sudo apt remove python3-pyaudio -sudo apt-get install -y libasound2-dev portaudio19-dev libportaudio2 -pip install --no-cache-dir pyaudio - -sudo apt-get install -y alsa-utils -speaker-test -t sine -f 440 - -sudo apt-get update -sudo apt-get install -y pulseaudio -sudo usermod -aG audio $USER -pulseaudio --start - -``` diff --git a/src/scitex/dsp/example.py b/src/scitex/dsp/example.py deleted file mode 100755 index f61ad6df9..000000000 --- a/src/scitex/dsp/example.py +++ /dev/null @@ -1,261 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "2024-04-06 01:36:18 (ywatanabe)" - -import matplotlib - -matplotlib.use("Agg") -import matplotlib.pyplot as plt -import pandas as pd - -import scitex - -# Module-level constants (defaults for example functions) -TGT_FS = 512 -LOW_HZ = 20 -HIGH_HZ = 50 -SIGMA = 10 - -# Default color cycle -CC = {"blue": "#1f77b4", "red": "#d62728", "green": "#2ca02c"} - - -# Functions -def calc_norm_resample_filt_hilbert(xx, tt, fs, sig_type, verbose=True): - sigs = {"index": ("signal", "time", "fs")} # Collector - - if sig_type == "tensorpac": - xx = xx[:, :, 0] - - sigs["orig"] = (xx, tt, fs) - - # Normalization - sigs["z_normed"] = (scitex.dsp.norm.z(xx), tt, fs) - sigs["minmax_normed"] = (scitex.dsp.norm.minmax(xx), tt, fs) - - # Resampling - resampled_xx = scitex.dsp.resample(xx, fs, TGT_FS) - # Create proper time vector for resampled signal - import numpy as np - - resampled_tt = np.linspace(tt[0], tt[-1], resampled_xx.shape[-1]) - sigs["resampled"] = (resampled_xx, resampled_tt, TGT_FS) - - # Noise injection - sigs["gaussian_noise_added"] = (scitex.dsp.add_noise.gauss(xx), tt, fs) - sigs["white_noise_added"] = (scitex.dsp.add_noise.white(xx), tt, fs) - sigs["pink_noise_added"] = (scitex.dsp.add_noise.pink(xx), tt, fs) - sigs["brown_noise_added"] = (scitex.dsp.add_noise.brown(xx), tt, fs) - - # Filtering (bands format is [[low_hz, high_hz]]) - bands = [[LOW_HZ, HIGH_HZ]] - sigs[f"bandpass_filted ({LOW_HZ} - {HIGH_HZ} Hz)"] = ( - scitex.dsp.filt.bandpass(xx, fs, bands), - tt, - fs, - ) - - sigs[f"bandstop_filted ({LOW_HZ} - {HIGH_HZ} Hz)"] = ( - scitex.dsp.filt.bandstop(xx, fs, bands), - tt, - fs, - ) - sigs[f"bandstop_gauss (sigma = {SIGMA})"] = ( - scitex.dsp.filt.gauss(xx, sigma=SIGMA), - tt, - fs, - ) - - # Hilbert Transformation - pha, amp = scitex.dsp.hilbert(xx) - sigs["hilbert_amp"] = (amp, tt, fs) - sigs["hilbert_pha"] = (pha, tt, fs) - - sigs = pd.DataFrame(sigs).set_index("index") - - if verbose: - print(sigs.index) - print(sigs.columns) - - return sigs - - -def plot_signals(plt, sigs, sig_type): - fig, axes = plt.subplots(nrows=len(sigs.columns), sharex=True) - - i_batch = 0 - i_ch = 0 - for ax, (i_col, col) in zip(axes, enumerate(sigs.columns)): - if col == "hilbert_amp": # add the original signal to the ax - _col = "orig" - ( - _xx, - _tt, - _fs, - ) = sigs[_col] - ax.plot(_tt, _xx[i_batch, i_ch], label=_col, c=CC["blue"]) - - # Main - xx, tt, fs = sigs[col] - # if sig_type == "tensorpac": - # xx = xx[:, :, 0] - - # Handle potential shape mismatches from filter operations - signal = xx[i_batch, i_ch] - if hasattr(signal, "squeeze"): - signal = signal.squeeze() - if hasattr(signal, "numpy"): - signal = signal.numpy() - - ax.plot( - tt, - signal, - label=col, - c=CC["red"] if col == "hilbert_amp" else CC["blue"], - ) - - # Adjustments - ax.legend(loc="upper left") - ax.set_xlim(tt[0], tt[-1]) - - ax = scitex.plt.ax.set_n_ticks(ax) - - fig.supxlabel("Time [s]") - fig.supylabel("Voltage") - fig.suptitle(sig_type) - return fig - - -def plot_wavelet(plt, sigs, sig_col, sig_type): - xx, tt, fs = sigs[sig_col] - # if sig_type == "tensorpac": - # xx = xx[:, :, 0] - - # Wavelet Transformation - wavelet_coef, ff_ww = scitex.dsp.wavelet(xx, fs) - - i_batch = 0 - i_ch = 0 - - # Main - fig, axes = plt.subplots(nrows=2, sharex=True) - # Signal - axes[0].plot( - tt, - xx[i_batch, i_ch], - label=sig_col, - c=CC["blue"], - ) - # Adjusts - axes[0].legend(loc="upper left") - axes[0].set_xlim(tt[0], tt[-1]) - axes[0].set_ylabel("Voltage") - axes[0] = scitex.plt.ax.set_n_ticks(axes[0]) - - # Wavelet Spectrogram - axes[1].imshow( - wavelet_coef[i_batch, i_ch], - aspect="auto", - extent=[tt[0], tt[-1], 512, 1], - label="wavelet_coefficient", - ) - # axes[1].set_xlabel("Time [s]") - axes[1].set_ylabel("Frequency [Hz]") - # axes[1].legend(loc="upper left") - axes[1].invert_yaxis() - - fig.supxlabel("Time [s]") - fig.suptitle(sig_type) - - return fig - - -def plot_psd(plt, sigs, sig_col, sig_type): - xx, tt, fs = sigs[sig_col] - - # if sig_type == "tensorpac": - # xx = xx[:, :, 0] - - # Power Spetrum Density - psd, ff_pp = scitex.dsp.psd(xx, fs) - - # Main - i_batch = 0 - i_ch = 0 - fig, axes = plt.subplots(nrows=2, sharex=False) - - # Signal - axes[0].plot( - tt, - xx[i_batch, i_ch], - label=sig_col, - c=CC["blue"], - ) - # Adjustments - axes[0].legend(loc="upper left") - axes[0].set_xlim(tt[0], tt[-1]) - axes[0].set_xlabel("Time [s]") - axes[0].set_ylabel("Voltage") - axes[0] = scitex.plt.ax.set_n_ticks(axes[0]) - - # PSD - axes[1].plot(ff_pp, psd[i_batch, i_ch], label="PSD") - axes[1].set_yscale("log") - axes[1].set_ylabel("Power [uV^2 / Hz]") - axes[1].set_xlabel("Frequency [Hz]") - - fig.suptitle(sig_type) - - return fig - - -if __name__ == "__main__": - # Parameters - T_SEC = 4 - SIG_TYPES = [ - # "uniform", - # "gauss", - # "periodic", - # "chirp", - # "ripple", - # "meg", - "tensorpac", - ] - SRC_FS = 1024 - TGT_FS = 512 - FREQS_HZ = [10, 30, 100] - LOW_HZ = 20 - HIGH_HZ = 50 - SIGMA = 10 - - plt, CC = scitex.plt.configure_mpl(plt, fig_scale=10) - sdir = "/home/ywatanabe/proj/entrance/scitex/dsp/example/" - - for sig_type in SIG_TYPES: - # Demo Signal - xx, tt, fs = scitex.dsp.demo_sig( - t_sec=T_SEC, fs=SRC_FS, freqs_hz=FREQS_HZ, sig_type=sig_type - ) - - # Apply calculations on the original signal - sigs = calc_norm_resample_filt_hilbert(xx, tt, fs, sig_type) - - # Plots signals - fig = plot_signals(plt, sigs, sig_type) - scitex.io.save(fig, sdir + f"{sig_type}/1_signals.png") - - # Plots wavelet coefficients and PSD - for sig_col in sigs.columns: - if "hilbert" in sig_col: - continue - - fig = plot_wavelet(plt, sigs, sig_col, sig_type) - scitex.io.save(fig, sdir + f"{sig_type}/2_wavelet_{sig_col}.png") - - fig = plot_psd(plt, sigs, sig_col, sig_type) - scitex.io.save(fig, sdir + f"{sig_type}/3_psd_{sig_col}.png") - - # plt.show() - - """ - python ./dsp/example.py - """ diff --git a/src/scitex/dsp/filt.py b/src/scitex/dsp/filt.py deleted file mode 100755 index 64289d093..000000000 --- a/src/scitex/dsp/filt.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "2024-11-04 02:05:47 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/filt.py - -import numpy as np - -import scitex -from scitex.decorators import signal_fn - -# No top-level imports from nn module to avoid circular dependency -# Filters will be imported inside functions when needed - - -@signal_fn -def gauss(x, sigma, t=None): - from scitex.nn._Filters import GaussianFilter - - return GaussianFilter(sigma)(x, t=t) - - -@signal_fn -def bandpass(x, fs, bands, t=None): - import torch - - from scitex.nn._Filters import BandPassFilter - - # Convert bands to tensor if it's not already - if not isinstance(bands, torch.Tensor): - bands = torch.tensor(bands, dtype=torch.float32) - return BandPassFilter(bands, fs, x.shape[-1])(x, t=t) - - -@signal_fn -def bandstop(x, fs, bands, t=None): - import torch - - from scitex.nn._Filters import BandStopFilter - - # Convert bands to tensor if it's not already - if not isinstance(bands, torch.Tensor): - bands = torch.tensor(bands, dtype=torch.float32) - return BandStopFilter(bands, fs, x.shape[-1])(x, t=t) - - -@signal_fn -def lowpass(x, fs, cutoffs_hz, t=None): - from scitex.nn._Filters import LowPassFilter - - return LowPassFilter(cutoffs_hz, fs, x.shape[-1])(x, t=t) - - -@signal_fn -def highpass(x, fs, cutoffs_hz, t=None): - from scitex.nn._Filters import HighPassFilter - - return HighPassFilter(cutoffs_hz, fs, x.shape[-1])(x, t=t) - - -def _custom_print(x): - print(type(x), x.shape) - - -if __name__ == "__main__": - import sys - - import matplotlib.pyplot as plt - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - - # Parametes - T_SEC = 1 - SRC_FS = 1024 - FREQS_HZ = list(np.linspace(0, 500, 10, endpoint=False).astype(int)) - SIG_TYPE = "periodic" - BANDS = np.vstack([[80, 310]]) - SIGMA = 3 - - # Demo Signal - xx, tt, fs = scitex.dsp.demo_sig( - t_sec=T_SEC, - fs=SRC_FS, - freqs_hz=FREQS_HZ, - sig_type=SIG_TYPE, - ) - - # Filtering - x_bp, t_bp = scitex.dsp.filt.bandpass(xx, fs, BANDS, t=tt) - x_bs, t_bs = scitex.dsp.filt.bandstop(xx, fs, BANDS, t=tt) - x_lp, t_lp = scitex.dsp.filt.lowpass(xx, fs, BANDS[:, 0], t=tt) - x_hp, t_hp = scitex.dsp.filt.highpass(xx, fs, BANDS[:, 1], t=tt) - x_g, t_g = scitex.dsp.filt.gauss(xx, sigma=SIGMA, t=tt) - filted = { - f"Original (Sum of {FREQS_HZ}-Hz signals)": (xx, tt, fs), - f"Bandpass-filtered ({BANDS[0][0]} - {BANDS[0][1]} Hz)": ( - x_bp, - t_bp, - fs, - ), - f"Bandstop-filtered ({BANDS[0][0]} - {BANDS[0][1]} Hz)": ( - x_bs, - t_bs, - fs, - ), - f"Lowpass-filtered ({BANDS[0][0]} Hz)": (x_lp, t_lp, fs), - f"Highpass-filtered ({BANDS[0][1]} Hz)": (x_hp, t_hp, fs), - f"Gaussian-filtered (sigma = {SIGMA} SD [point])": (x_g, t_g, fs), - } - - # Plots traces - fig, axes = plt.subplots(nrows=len(filted), ncols=1, sharex=True, sharey=True) - i_batch = 0 - i_ch = 0 - i_filt = 0 - for ax, (k, v) in zip(axes, filted.items()): - _xx, _tt, _fs = v - if _xx.ndim == 3: - _xx = _xx[i_batch, i_ch] - elif _xx.ndim == 4: - _xx = _xx[i_batch, i_ch, i_filt] - ax.plot(_tt, _xx, label=k) - ax.legend(loc="upper left") - - fig.suptitle("Filtered") - fig.supxlabel("Time [s]") - fig.supylabel("Amplitude") - - scitex.io.save(fig, "traces.png") - - # Calculates and Plots PSD - fig, axes = plt.subplots(nrows=len(filted), ncols=1, sharex=True, sharey=True) - i_batch = 0 - i_ch = 0 - i_filt = 0 - for ax, (k, v) in zip(axes, filted.items()): - _xx, _tt, _fs = v - - _psd, ff = scitex.dsp.psd(_xx, _fs) - if _psd.ndim == 3: - _psd = _psd[i_batch, i_ch] - elif _psd.ndim == 4: - _psd = _psd[i_batch, i_ch, i_filt] - - ax.plot(ff, _psd, label=k) - ax.legend(loc="upper left") - - for bb in np.hstack(BANDS): - ax.axvline(x=bb, color=CC["grey"], linestyle="--") - - fig.suptitle("PSD (power spectrum density) of filtered signals") - fig.supxlabel("Frequency [Hz]") - fig.supylabel("log(Power [uV^2 / Hz]) [a.u.]") - scitex.io.save(fig, "psd.png") - - # Close - scitex.session.close(CONFIG) - -# EOF - -""" -/home/ywatanabe/proj/scitex/src/scitex/dsp/filt.py -""" - -# EOF diff --git a/src/scitex/dsp/norm.py b/src/scitex/dsp/norm.py deleted file mode 100755 index 676898c5d..000000000 --- a/src/scitex/dsp/norm.py +++ /dev/null @@ -1,33 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "2024-04-05 12:15:42 (ywatanabe)" - -try: - import torch as _torch - - TORCH_AVAILABLE = True -except ImportError: - TORCH_AVAILABLE = False - _torch = None - -from scitex.decorators import signal_fn as _signal_fn - - -def _check_torch(): - if not TORCH_AVAILABLE: - raise ImportError( - "PyTorch is not installed. Please install with: pip install torch" - ) - - -@_signal_fn -def z(x, dim=-1): - _check_torch() - return (x - x.mean(dim=dim, keepdim=True)) / x.std(dim=dim, keepdim=True) - - -@_signal_fn -def minmax(x, amp=1.0, dim=-1, fn="mean"): - _check_torch() - MM = x.max(dim=dim, keepdims=True)[0].abs() - mm = x.min(dim=dim, keepdims=True)[0].abs() - return amp * x / _torch.maximum(MM, mm) diff --git a/src/scitex/dsp/params.py b/src/scitex/dsp/params.py deleted file mode 100755 index 2b22e350e..000000000 --- a/src/scitex/dsp/params.py +++ /dev/null @@ -1,51 +0,0 @@ -import numpy as np -import pandas as pd - -BANDS = pd.DataFrame( - data=np.array([[0.5, 4], [4, 8], [8, 10], [10, 13], [13, 32], [32, 75]]).T, - index=["low_hz", "high_hz"], - columns=["delta", "theta", "lalpha", "halpha", "beta", "gamma"], -) - -EEG_MONTAGE_1020 = [ - "FP1", - "F3", - "C3", - "P3", - "O1", - "FP2", - "F4", - "C4", - "P4", - "O2", - "F7", - "T7", - "P7", - "F8", - "T8", - "P8", - "FZ", - "CZ", - "PZ", -] - -EEG_MONTAGE_BIPOLAR_TRANVERSE = [ - # Frontal - "FP1-FP2", - "F7-F3", - "F3-FZ", - "FZ-F4", - "F4-F8", - # Central - "T7-C3", - "C3-CZ", - "CZ-C4", - "C4-T8", - # Parietal - "P7-P3", - "P3-PZ", - "PZ-P4", - "P4-P8", - # Occipital - "O1-O2", -] diff --git a/src/scitex/dsp/reference.py b/src/scitex/dsp/reference.py deleted file mode 100755 index d1c1b1b06..000000000 --- a/src/scitex/dsp/reference.py +++ /dev/null @@ -1,60 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "ywatanabe (2024-11-02 22:48:44)" -# File: ./scitex_repo/src/scitex/dsp/reference.py - -try: - import torch as _torch - - TORCH_AVAILABLE = True -except ImportError: - TORCH_AVAILABLE = False - _torch = None - -from scitex.decorators import torch_fn as _torch_fn - - -def _check_torch(): - if not TORCH_AVAILABLE: - raise ImportError( - "PyTorch is not installed. Please install with: pip install torch" - ) - - -@_torch_fn -def common_average(x, dim=-2): - _check_torch() - re_referenced = (x - x.mean(dim=dim, keepdims=True)) / x.std(dim=dim, keepdims=True) - assert x.shape == re_referenced.shape - return re_referenced - - -@_torch_fn -def random(x, dim=-2): - _check_torch() - idx_all = [slice(None)] * x.ndim - idx_rand_dim = _torch.randperm(x.shape[dim]) - idx_all[dim] = idx_rand_dim - y = x[idx_all] - re_referenced = x - y - assert x.shape == re_referenced.shape - return re_referenced - - -@_torch_fn -def take_reference(x, tgt_indi, dim=-2): - _check_torch() - idx_all = [slice(None)] * x.ndim - idx_all[dim] = tgt_indi - ref = x[tuple(idx_all)].unsqueeze(dim) - re_referenced = x - ref - assert x.shape == re_referenced.shape - return re_referenced - - -if __name__ == "__main__": - import scitex - - x, f, t = scitex.dsp.demo_sig() - y = common_average(x) - -# EOF diff --git a/src/scitex/dsp/template.py b/src/scitex/dsp/template.py deleted file mode 100755 index 08760333d..000000000 --- a/src/scitex/dsp/template.py +++ /dev/null @@ -1,26 +0,0 @@ -#!./env/bin/python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-04-10 15:57:54 (ywatanabe)" - - -# FUnctions - -if __name__ == "__main__": - import sys - - import matplotlib.pyplot as plt - import torch - - import scitex - - # Start - CONFIG, sys.stdout, sys.stderr, plt, cc = scitex.session.start(sys, plt) - - # Close - scitex.session.close(CONFIG) - - """ - /home/ywatanabe/proj/scitex/src/scitex/dsp/template.py - """ - - # EOF diff --git a/src/scitex/dsp/utils/__init__.py b/src/scitex/dsp/utils/__init__.py deleted file mode 100755 index e6e20bdd7..000000000 --- a/src/scitex/dsp/utils/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/env python3 -"""Scitex utils module.""" - -from ._differential_bandpass_filters import ( - build_bandpass_filters, - init_bandpass_filters, -) -from ._ensure_3d import ensure_3d -from ._ensure_even_len import ensure_even_len -from ._zero_pad import _zero_pad_1d, zero_pad - -__all__ = [ - "_zero_pad_1d", - "build_bandpass_filters", - "ensure_3d", - "ensure_even_len", - "init_bandpass_filters", - "zero_pad", -] diff --git a/src/scitex/dsp/utils/_differential_bandpass_filters.py b/src/scitex/dsp/utils/_differential_bandpass_filters.py deleted file mode 100755 index 7b75ec7a7..000000000 --- a/src/scitex/dsp/utils/_differential_bandpass_filters.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "2024-11-26 22:24:13 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/utils/_differential_bandpass_filters.py - -THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/dsp/utils/_differential_bandpass_filters.py" - -import sys - -import matplotlib.pyplot as plt -import numpy as np - -try: - import torch - import torch.nn as nn - - TORCH_AVAILABLE = True -except ImportError: - TORCH_AVAILABLE = False - torch = None - nn = None - -from scitex.decorators import torch_fn -from scitex.gen._to_even import to_even -from scitex.gen._to_odd import to_odd - -try: - from torchaudio.prototype.functional import sinc_impulse_response - - TORCHAUDIO_AVAILABLE = True -except ImportError: - TORCHAUDIO_AVAILABLE = False - sinc_impulse_response = None - - -def _check_torch(): - if not TORCH_AVAILABLE: - raise ImportError( - "PyTorch is not installed. Please install with: pip install torch" - ) - - -def _check_sinc_available(): - if sinc_impulse_response is None: - raise ImportError( - "sinc_impulse_response requires torchaudio.prototype.functional. " - "Install torchaudio with: pip install torchaudio" - ) - - -# Functions -@torch_fn -def init_bandpass_filters( - sig_len, - fs, - pha_low_hz=2, - pha_high_hz=20, - pha_n_bands=30, - amp_low_hz=60, - amp_high_hz=160, - amp_n_bands=50, - cycle=3, -): - _check_sinc_available() - # Learnable parameters - pha_mids = nn.Parameter(torch.linspace(pha_low_hz, pha_high_hz, pha_n_bands)) - amp_mids = nn.Parameter(torch.linspace(amp_low_hz, amp_high_hz, amp_n_bands)) - filters = build_bandpass_filters(sig_len, fs, pha_mids, amp_mids, cycle) - return filters, pha_mids, amp_mids - - -@torch_fn -def build_bandpass_filters(sig_len, fs, pha_mids, amp_mids, cycle): - _check_sinc_available() - - def _define_freqs(mids, factor): - lows = mids - mids / factor - highs = mids + mids / factor - return lows, highs - - def define_order(low_hz, fs, sig_len, cycle): - order = cycle * int(fs // low_hz) - order = order if 3 * order >= sig_len else (sig_len - 1) // 3 - order = to_even(order) - return order - - def _calc_filters(lows_hz, highs_hz, fs, order): - nyq = fs / 2.0 - order = to_odd(order) - # lowpass filters - irs_ll = sinc_impulse_response(lows_hz / nyq, window_size=order) - irs_hh = sinc_impulse_response(highs_hz / nyq, window_size=order) - irs = irs_ll - irs_hh - return irs - - # Main - pha_lows, pha_highs = _define_freqs(pha_mids, factor=4.0) - amp_lows, amp_highs = _define_freqs(amp_mids, factor=8.0) - - lowest = min(pha_lows.min().item(), amp_lows.min().item()) - order = define_order(lowest, fs, sig_len, cycle) - - pha_bp_filters = _calc_filters(pha_lows, pha_highs, fs, order) - amp_bp_filters = _calc_filters(amp_lows, amp_highs, fs, order) - return torch.vstack([pha_bp_filters, amp_bp_filters]) - - -if __name__ == "__main__": - import scitex - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt, agg=True) - - # Demo signal - freqs_hz = [10, 30, 100, 300] - fs = 1024 - xx, tt, fs = scitex.dsp.demo_sig(fs=fs, freqs_hz=freqs_hz) - - # Main - filters, pha_mids, amp_mids = init_bandpass_filters(xx.shape[-1], fs) - - filters.sum().backward() # OK. The filtering bands are trainable with backpropagation. - - # Update 'pha_mids' and 'amp_mids' in the forward method. - # Then, re-build filters using optimized parameters like this: - # self.filters = build_bandpass_filters(self.sig_len, self.fs, self.pha_mids, self.amp_mids, self.cycle) - - mids_all = np.concatenate( - [pha_mids.detach().cpu().numpy(), amp_mids.detach().cpu().numpy()] - ) - - for i_filter in range(len(mids_all)): - mid = mids_all[i_filter] - fig = scitex.dsp.utils.filter.plot_filter_responses( - filters[i_filter].detach().cpu().numpy(), fs, title=f"{mid:.1f} Hz" - ) - scitex.io.save( - fig, - f"differentiable_bandpass_filter_reponses_filter#{i_filter:03d}_{mid:.1f}_Hz.png", - ) - # plt.show() - -# EOF - -""" -python -m scitex.dsp.utils._differential_bandpass_filters -""" - -# EOF diff --git a/src/scitex/dsp/utils/_ensure_3d.py b/src/scitex/dsp/utils/_ensure_3d.py deleted file mode 100755 index 08f48786a..000000000 --- a/src/scitex/dsp/utils/_ensure_3d.py +++ /dev/null @@ -1,18 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-11-05 01:04:03 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/utils/_ensure_3d.py - -from scitex.decorators import torch_fn - - -@torch_fn -def ensure_3d(x): - if x.ndim == 1: # assumes (seq_len,) - x = x.unsqueeze(0).unsqueeze(0) - elif x.ndim == 2: # assumes (batch_siize, seq_len) - x = x.unsqueeze(1) - return x - - -# EOF diff --git a/src/scitex/dsp/utils/_ensure_even_len.py b/src/scitex/dsp/utils/_ensure_even_len.py deleted file mode 100755 index 294e43927..000000000 --- a/src/scitex/dsp/utils/_ensure_even_len.py +++ /dev/null @@ -1,10 +0,0 @@ -#!./env/bin/python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-04-10 11:59:49 (ywatanabe)" - - -def ensure_even_len(x): - if x.shape[-1] % 2 == 0: - return x - else: - return x[..., :-1] diff --git a/src/scitex/dsp/utils/_zero_pad.py b/src/scitex/dsp/utils/_zero_pad.py deleted file mode 100755 index a0e479744..000000000 --- a/src/scitex/dsp/utils/_zero_pad.py +++ /dev/null @@ -1,62 +0,0 @@ -#!/usr/bin/env python3 -# Time-stamp: "2024-11-26 10:30:34 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/utils/_zero_pad.py - -THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/dsp/utils/_zero_pad.py" - -import numpy as np - -try: - import torch - import torch.nn.functional as F - - TORCH_AVAILABLE = True -except ImportError: - TORCH_AVAILABLE = False - torch = None - F = None - - -def _check_torch(): - if not TORCH_AVAILABLE: - raise ImportError( - "PyTorch is not installed. Please install with: pip install torch" - ) - - -def _zero_pad_1d(x, target_length): - """Zero pad a 1D tensor to target length.""" - _check_torch() - if not isinstance(x, torch.Tensor): - x = torch.tensor(x) - padding_needed = target_length - len(x) - padding_left = padding_needed // 2 - padding_right = padding_needed - padding_left - return F.pad(x, (padding_left, padding_right), "constant", 0) - - -def zero_pad(xs, dim=0): - """Zero pad a list of arrays to the same length. - - Args: - xs: List of tensors or arrays - dim: Dimension to stack along - - Returns: - Stacked tensor with zero padding - """ - # Convert to tensors if needed - tensors = [] - for x in xs: - if isinstance(x, np.ndarray): - tensors.append(torch.tensor(x)) - elif isinstance(x, torch.Tensor): - tensors.append(x) - else: - tensors.append(torch.tensor(x)) - - max_len = max([len(x) for x in tensors]) - return torch.stack([_zero_pad_1d(x, max_len) for x in tensors], dim=dim) - - -# EOF diff --git a/src/scitex/dsp/utils/filter.py b/src/scitex/dsp/utils/filter.py deleted file mode 100755 index 572163915..000000000 --- a/src/scitex/dsp/utils/filter.py +++ /dev/null @@ -1,408 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-11-03 07:24:43 (ywatanabe)" -# File: ./scitex_repo/src/scitex/dsp/utils/filter.py - -import matplotlib.pyplot as plt -import numpy as np -from scipy.signal import firwin, freqz - -from scitex.decorators import numpy_fn -from scitex.gen._to_even import to_even - - -@numpy_fn -def design_filter(sig_len, fs, low_hz=None, high_hz=None, cycle=3, is_bandstop=False): - """ - Designs a Finite Impulse Response (FIR) filter based on the specified parameters. - - Arguments: - - sig_len (int): Length of the signal for which the filter is being designed. - - fs (int): Sampling frequency of the signal. - - low_hz (float, optional): Low cutoff frequency for the filter. Required for lowpass and bandpass filters. - - high_hz (float, optional): High cutoff frequency for the filter. Required for highpass and bandpass filters. - - cycle (int, optional): Number of cycles to use in determining the filter order. Defaults to 3. - - is_bandstop (bool, optional): Specifies if the filter should be a bandstop filter. Defaults to False. - - Returns: - - The coefficients of the designed FIR filter. - - Raises: - - FilterParameterError: If the provided parameters are invalid. - """ - - class FilterParameterError(Exception): - """Custom exception for invalid filter parameters.""" - - pass - - def estimate_filter_type(low_hz=None, high_hz=None, is_bandstop=False): - """ - Estimates the filter type based on the provided low and high cutoff frequencies, - and whether a bandstop filter is desired. Raises an exception for invalid configurations. - """ - if low_hz is not None and low_hz < 0: - raise FilterParameterError("low_hz must be non-negative.") - if high_hz is not None and high_hz < 0: - raise FilterParameterError("high_hz must be non-negative.") - if low_hz is not None and high_hz is not None and low_hz >= high_hz: - raise FilterParameterError( - "low_hz must be less than high_hz for valid configurations." - ) - - if low_hz is not None and high_hz is not None: - return "bandstop" if is_bandstop else "bandpass" - elif low_hz is not None: - return "lowpass" - elif high_hz is not None: - return "highpass" - else: - raise FilterParameterError( - "At least one of low_hz or high_hz must be provided." - ) - - def determine_cutoff_frequencies(filter_mode, low_hz, high_hz): - if filter_mode in ["lowpass", "highpass"]: - cutoff = low_hz if filter_mode == "lowpass" else high_hz - else: # 'bandpass' or 'bandstop' - cutoff = [low_hz, high_hz] - return cutoff - - def determine_low_freq(filter_mode, low_hz, high_hz): - if filter_mode in ["lowpass", "bandstop"]: - low_freq = low_hz - else: # 'highpass' or 'bandpass' - low_freq = high_hz if filter_mode == "highpass" else min(low_hz, high_hz) - return low_freq - - def determine_order(filter_mode, fs, low_freq, sig_len, cycle): - order = cycle * int((fs // low_freq)) - if 3 * order < sig_len: - order = (sig_len - 1) // 3 - order = to_even(order) - return order - - fs = int(fs) - low_hz = float(low_hz) if low_hz is not None else low_hz - high_hz = float(high_hz) if high_hz is not None else high_hz - filter_mode = estimate_filter_type(low_hz, high_hz, is_bandstop) - cutoff = determine_cutoff_frequencies(filter_mode, low_hz, high_hz) - low_freq = determine_low_freq(filter_mode, low_hz, high_hz) - order = determine_order(filter_mode, fs, low_freq, sig_len, cycle) - numtaps = order + 1 - - try: - h = firwin( - numtaps=numtaps, - cutoff=cutoff, - pass_zero=(filter_mode in ["highpass", "bandstop"]), - window="hamming", - fs=fs, - scale=True, - ) - except Exception as e: - print(e) - import ipdb - - ipdb.set_trace() - - return h - - -@numpy_fn -def plot_filter_responses(filter, fs, worN=8000, title=None): - """ - Plots the impulse and frequency response of an FIR filter using numpy arrays. - - Parameters: - - filter_coeffs (numpy.ndarray): The filter coefficients as a numpy array. - - fs (int): The sampling frequency in Hz. - - title (str, optional): The title of the plot. Defaults to None. - - Returns: - - matplotlib.figure.Figure: The figure object containing the impulse and frequency response plots. - """ - import scitex - - ww, hh = freqz(filter, worN=worN, fs=fs) - - fig, axes = scitex.plt.subplots(ncols=2) - fig.suptitle(title) - - # Impulse Responses of FIR Filter - ax = axes[0] - ax.plot(filter) - ax.set_title("Impulse Responses of FIR Filter") - ax.set_xlabel("Tap Number") - ax.set_ylabel("Amplitude") - - # Frequency Response of FIR Filter - ax = axes[1] - ax.plot(ww, 20 * np.log10(abs(hh) + 1e-5)) - ax.set_title("Frequency Response of FIR Filter") - ax.set_xlabel("Frequency [Hz]") - ax.set_ylabel("Gain [dB]") - - return fig - - -if __name__ == "__main__": - import scitex - - # Example usage - xx, tt, fs = scitex.dsp.demo_sig() - batch_size, n_chs, seq_len = xx.shape - - lp_filter = design_filter(seq_len, fs, low_hz=30, high_hz=None) - hp_filter = design_filter(seq_len, fs, low_hz=None, high_hz=70) - bp_filter = design_filter(seq_len, fs, low_hz=30, high_hz=70) - bs_filter = design_filter(seq_len, fs, low_hz=30, high_hz=70, is_bandstop=True) - - fig = plot_filter_responses(lp_filter, fs, title="Lowpass Filter") - fig = plot_filter_responses(hp_filter, fs, title="Highpass Filter") - fig = plot_filter_responses(bp_filter, fs, title="Bandpass Filter") - fig = plot_filter_responses(bs_filter, fs, title="Bandstop Filter") - - # Figure - fig, axes = plt.subplots(nrows=4, ncols=2) - - # Time domain expressions?? - axes[0, 0].plot(lp_filter, label="Lowpass Filter") - axes[1, 0].plot(hp_filter, label="Highpass Filter") - axes[2, 0].plot(bp_filter, label="Bandpass Filter") - axes[3, 0].plot(bs_filter, label="Bandstop Filter") - # fig.suptitle("Impulse Responses of FIR Filter") - # fig.supxlabel("Tap Number") - # fig.supylabel("Amplitude") - # fig.show() - - # Frequency response of the filters - w, h_lp = freqz(lp_filter, worN=8000, fs=fs) - w, h_hp = freqz(hp_filter, worN=8000, fs=fs) - w, h_bp = freqz(bp_filter, worN=8000, fs=fs) - w, h_bs = freqz(bs_filter, worN=8000, fs=fs) - - # Plotting the frequency response - axes[0, 1].plot(w, 20 * np.log10(abs(h_lp)), label="Lowpass Filter") - axes[1, 1].plot(w, 20 * np.log10(abs(h_hp)), label="Highpass Filter") - axes[2, 1].plot(w, 20 * np.log10(abs(h_bp)), label="Bandpass Filter") - axes[3, 1].plot(w, 20 * np.log10(abs(h_bs)), label="Bandstop Filter") - # plt.title("Frequency Response of FIR Filters") - # plt.xlabel("Frequency (Hz)") - # plt.ylabel("Gain (dB)") - # plt.grid(True) - # plt.legend(loc="best") - # plt.show() - fig.tight_layout() - plt.show() - -# @torch_fn -# def bandpass(x, filt): -# assert x.ndim == 3 -# xf = F.conv1d( -# x.reshape(-1, x.shape[-1]).unsqueeze(1), -# filt.unsqueeze(0).unsqueeze(0), -# padding="same", -# ).reshape(*x.shape) -# assert x.shape == xf.shape -# return xf - -# def define_bandpass_filters(seq_len, fs, freq_bands, cycle=3): -# """ -# Defines Finite Impulse Response (FIR) filters. -# b: The filter coefficients (or taps) of the FIR filters -# a: The denominator coefficients of the filter's transfer function. However, FIR filters have a transfer function with a denominator equal to 1 (since they are all-zero filters with no poles). -# """ -# # Parameters -# n_freqs = len(freq_bands) -# nyq = fs / 2.0 - -# bs = [] -# for ll, hh in freq_bands: -# wn = np.array([ll, hh]) / nyq -# order = define_fir_order(fs, seq_len, ll, cycle=cycle) -# bs.append(fir1(order, wn)[0]) -# return bs - -# def define_fir_order(fs, sizevec, flow, cycle=3): -# """ -# Calculate filter order. -# """ -# if cycle is None: -# filtorder = 3 * np.fix(fs / flow) -# else: -# filtorder = cycle * (fs // flow) - -# if sizevec < 3 * filtorder: -# filtorder = (sizevec - 1) // 3 - -# return int(filtorder) - -# def n_odd_fcn(f, o, w, l): -# """Odd case.""" -# # Variables : -# b0 = 0 -# m = np.array(range(int(l + 1))) -# k = m[1 : len(m)] -# b = np.zeros(k.shape) - -# # Run Loop : -# for s in range(0, len(f), 2): -# m = (o[s + 1] - o[s]) / (f[s + 1] - f[s]) -# b1 = o[s] - m * f[s] -# b0 = b0 + ( -# b1 * (f[s + 1] - f[s]) -# + m / 2 * (f[s + 1] * f[s + 1] - f[s] * f[s]) -# ) * abs(np.square(w[round((s + 1) / 2)])) -# b = b + ( -# m -# / (4 * np.pi * np.pi) -# * ( -# np.cos(2 * np.pi * k * f[s + 1]) -# - np.cos(2 * np.pi * k * f[s]) -# ) -# / (k * k) -# ) * abs(np.square(w[round((s + 1) / 2)])) -# b = b + ( -# f[s + 1] * (m * f[s + 1] + b1) * np.sinc(2 * k * f[s + 1]) -# - f[s] * (m * f[s] + b1) * np.sinc(2 * k * f[s]) -# ) * abs(np.square(w[round((s + 1) / 2)])) - -# b = np.insert(b, 0, b0) -# a = (np.square(w[0])) * 4 * b -# a[0] = a[0] / 2 -# aud = np.flipud(a[1 : len(a)]) / 2 -# a2 = np.insert(aud, len(aud), a[0]) -# h = np.concatenate((a2, a[1:] / 2)) - -# return h - -# def n_even_fcn(f, o, w, l): -# """Even case.""" -# # Variables : -# k = np.array(range(0, int(l) + 1, 1)) + 0.5 -# b = np.zeros(k.shape) - -# # # Run Loop : -# for s in range(0, len(f), 2): -# m = (o[s + 1] - o[s]) / (f[s + 1] - f[s]) -# b1 = o[s] - m * f[s] -# b = b + ( -# m -# / (4 * np.pi * np.pi) -# * ( -# np.cos(2 * np.pi * k * f[s + 1]) -# - np.cos(2 * np.pi * k * f[s]) -# ) -# / (k * k) -# ) * abs(np.square(w[round((s + 1) / 2)])) -# b = b + ( -# f[s + 1] * (m * f[s + 1] + b1) * np.sinc(2 * k * f[s + 1]) -# - f[s] * (m * f[s] + b1) * np.sinc(2 * k * f[s]) -# ) * abs(np.square(w[round((s + 1) / 2)])) - -# a = (np.square(w[0])) * 4 * b -# h = 0.5 * np.concatenate((np.flipud(a), a)) - -# return h - -# def firls(n, f, o): -# # Variables definition : -# w = np.ones(round(len(f) / 2)) -# n += 1 -# f /= 2 -# lo = (n - 1) / 2 - -# nodd = bool(n % 2) - -# if nodd: # Odd case -# h = n_odd_fcn(f, o, w, lo) -# else: # Even case -# h = n_even_fcn(f, o, w, lo) - -# return h - -# def fir1(n, wn): -# # Variables definition : -# nbands = len(wn) + 1 -# ff = np.array((0, wn[0], wn[0], wn[1], wn[1], 1)) - -# f0 = np.mean(ff[2:4]) -# lo = n + 1 - -# mags = np.array(range(nbands)).reshape(1, -1) % 2 -# aa = np.ravel(np.tile(mags, (2, 1)), order="F") - -# # Get filter coefficients : -# h = firls(lo - 1, ff, aa) - -# # Apply a window to coefficients : -# wind = np.hamming(lo) -# b = h * wind -# c = np.exp(-1j * 2 * np.pi * (f0 / 2) * np.array(range(lo))) -# b /= abs(c @ b) - -# return b, 1 - -# def apply_filters(x, filts): -# """ -# x: (batch_size, n_chs, seq_len) -# filts: (n_filts, seq_len_filt) -# """ -# assert x.ndims == 3 -# assert filts.ndims == 2 -# batch_size, n_chs, n_time = x.shape -# x = x.reshape(-1, x.shape[-1]).unsqueeze(1) -# filts = filts.unsqueeze(1) -# n_filts = len(filts) -# return F.conv1d(x, filts, padding="same").reshape( -# batch_size, n_chs, n_filts, n_time -# ) - -# if __name__ == "__main__": -# import torch -# import torch.nn.functional as F - -# plt, CC = scitex.plt.configure_mpl(plt) - -# # Demo Signal -# freqs_hz = [10, 30, 100] -# xx, tt, fs = scitex.dsp.demo_sig(freqs_hz=freqs_hz, sig_type="periodic") -# x = xx - -# seq_len = x.shape[-1] -# freq_bands = np.array([[20, 70], [3.0, 4.0]]) - -# # Plots the figure -# fig, ax = scitex.plt.subplots() -# # ax.plot(b, label="bandpass filter") - -# # Bandpass Filtering -# filters = define_bandpass_filters(seq_len, fs, freq_bands, cycle=3) -# i_filt = 0 -# # xf = bandpass(xx, filters[i_filt]) - -# # Plots the signals -# fig, axes = scitex.plt.subplots(nrows=2, sharex=True, sharey=True) -# axes[0].plot(tt, xx[0, 0], label="orig") -# axes[1].plot(tt, xf[0, 0], label="orig") -# [ax.legend(loc="upper left") for ax in axes] - -# # Plots PSDs -# psd_xx, ff_xx = scitex.dsp.psd(xx.numpy(), fs) -# psd_xf, ff_xf = scitex.dsp.psd(xf.numpy(), fs) - -# fig, axes = scitex.plt.subplots(nrows=2, sharex=True, sharey=True) -# axes[0].plot(ff_xx, psd_xx[0, 0], label="orig") -# axes[1].plot(ff_xf, psd_xf[0, 0], label="filted") -# [ax.legend(loc="upper left") for ax in axes] -# plt.show() - -# # Multiple Filters in a parallel computation -# x = torch.randn(33, 32, 30) -# filters = torch.randn(20, 5) - -# y = apply_filters(x, filters) -# print(y.shape) # (33, 32, 20, 30) - -# EOF diff --git a/src/scitex/dsp/utils/pac.py b/src/scitex/dsp/utils/pac.py deleted file mode 100755 index ffbdab6a4..000000000 --- a/src/scitex/dsp/utils/pac.py +++ /dev/null @@ -1,178 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-26 06:26:29 (ywatanabe)" -# File: /ssh:ywatanabe@sp:/home/ywatanabe/proj/scitex_repo/src/scitex/dsp/utils/pac.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/dsp/utils/pac.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -#! ./env/bin/python3 -# Time-stamp: "2024-04-16 17:07:27" - - -""" -This script does XYZ. -""" - -# Imports -import sys - -import matplotlib.pyplot as plt -import numpy as np -import tensorpac - -import scitex - - -# Functions -def calc_pac_with_tensorpac(xx, fs, t_sec, i_batch=0, i_ch=0): - # Morlet's Wavelet Transfrmation - p = tensorpac.Pac(f_pha="hres", f_amp="mres", dcomplex="wavelet") - - # Bandpass Filtering and Hilbert Transformation - phases = p.filter(fs, xx[i_batch, i_ch], ftype="phase", n_jobs=1) # (50, 20, 2048) - amplitudes = p.filter( - fs, xx[i_batch, i_ch], ftype="amplitude", n_jobs=1 - ) # (50, 20, 2048) - - # Calculates xpac - k = 2 - p.idpac = (k, 0, 0) - xpac = p.fit(phases, amplitudes) # (50, 50, 20) - pac = xpac.mean(axis=-1) # (50, 50) - - freqs_amp = p.f_amp.mean(axis=-1) - freqs_pha = p.f_pha.mean(axis=-1) - - pac = pac.T # (amp, pha) -> (pha, amp) - - return phases, amplitudes, freqs_pha, freqs_amp, pac - - -def plot_PAC_scitex_vs_tensorpac(pac_scitex, pac_tp, freqs_pha, freqs_amp): - assert pac_scitex.shape == pac_tp.shape - - # Plots - fig, axes = scitex.plt.subplots(ncols=3) # , sharex=True, sharey=True - - # To align scalebars - vmin = min(np.min(pac_scitex), np.min(pac_tp), np.min(pac_scitex - pac_tp)) - vmax = max(np.max(pac_scitex), np.max(pac_tp), np.max(pac_scitex - pac_tp)) - - # scitex version - ax = axes[0] - ax.imshow2d( - pac_scitex, - cbar=False, - vmin=vmin, - vmax=vmax, - ) - ax.set_title("scitex") - - # Tensorpac - ax = axes[1] - ax.imshow2d( - pac_tp, - cbar=False, - vmin=vmin, - vmax=vmax, - ) - ax.set_title("Tensorpac") - - # Diff. - ax = axes[2] - ax.imshow2d( - pac_scitex - pac_tp, - cbar_label="PAC values", - cbar_shrink=0.5, - vmin=vmin, - vmax=vmax, - ) - ax.set_title(f"Difference\n(scitex - Tensorpac)") - - # for ax in axes: - # ax.set_ticks( - # x_vals=freqs_pha, - # # y_vals=freqs_amp, - # ) - # # ax.set_n_ticks() - - fig.suptitle("PAC (MI) values") - fig.supxlabel("Frequency for phase [Hz]") - fig.supylabel("Frequency for amplitude [Hz]") - - return fig - - -# Snake_case alias for consistency -def plot_pac_scitex_vs_tensorpac(pac_scitex, pac_tp, freqs_pha, freqs_amp): - """ - Plot comparison between SciTeX and Tensorpac phase-amplitude coupling results. - - This is an alias for plot_PAC_scitex_vs_tensorpac with snake_case naming. - - Parameters - ---------- - pac_scitex : array-like - PAC values from SciTeX - pac_tp : array-like - PAC values from Tensorpac - freqs_pha : array-like - Phase frequencies - freqs_amp : array-like - Amplitude frequencies - - Returns - ------- - fig : matplotlib.figure.Figure - The generated figure - """ - return plot_PAC_scitex_vs_tensorpac(pac_scitex, pac_tp, freqs_pha, freqs_amp) - - -if __name__ == "__main__": - import torch - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - - # Parameters - FS = 512 - T_SEC = 4 - - xx, tt, fs = scitex.dsp.demo_sig( - batch_size=2, - n_chs=2, - n_segments=2, - fs=FS, - t_sec=T_SEC, - sig_type="tensorpac", - ) - - # scitex - pac_scitex, freqs_pha, freqs_amp = scitex.dsp.pac( - xx, fs, batch_size=2, pha_n_bands=50, amp_n_bands=30 - ) - i_batch, i_epoch = 0, 0 - pac_scitex = pac_scitex[i_batch, i_epoch] - - # Tensorpac - phases, amplitudes, freqs_pha, freqs_amp, pac_tp = calc_pac_with_tensorpac( - xx, fs, T_SEC, i_batch=0, i_ch=0 - ) - - # Plots - fig = plot_PAC_scitex_vs_tensorpac(pac_scitex, pac_tp, freqs_pha, freqs_amp) - plt.show() - - # Close - scitex.session.close(CONFIG) - -""" -/home/ywatanabe/proj/entrance/scitex/dsp/utils/pac.py -""" - -# EOF