diff --git a/pyproject.toml b/pyproject.toml
index bdfecd610..22c164250 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -294,25 +294,7 @@ dict = ["scitex-dict>=0.1.2"]
# DSP Module - Digital Signal Processing
# Use: pip install scitex[dsp]
-dsp = [
- "scipy",
- "sounddevice",
- "matplotlib",
- "joblib",
- "tensorpac",
- "ruamel.yaml",
- "h5py",
- "readchar",
- "xarray",
- "seaborn",
- # # Heavy dependencies handled by _AVAILABLE flags
- # "mne",
- # "torch",
- # "ripple_detection",
- # "torchsummary",
- # "julius",
- # "torchaudio",
-]
+dsp = ["scitex-dsp>=0.1.0"]
# DevTools Module - scitex.dev code analysis and development tools
# Use: pip install scitex[devtools]
diff --git a/src/scitex/dsp/README.md b/src/scitex/dsp/README.md
deleted file mode 100755
index 4a7ad35d9..000000000
--- a/src/scitex/dsp/README.md
+++ /dev/null
@@ -1,147 +0,0 @@
-
-
-
-# [`scitex.dsp`](https://github.com/ywatanabe1989/scitex/tree/main/src/scitex/dsp/)
-
-## Overview
-The `scitex.dsp` module provides Digital Signal Processing (DSP) utilities written in **PyTorch**, optimized for **CUDA** devices when available. This module offers efficient implementations of various DSP algorithms and techniques.
-
-## Installation
-```bash
-pip install scitex
-```
-
-## Features
-- PyTorch-based implementations for GPU acceleration
-- Wavelet transforms and analysis
-- Filtering operations (e.g., bandpass, lowpass, highpass)
-- Spectral analysis tools
-- Time-frequency analysis utilities
-- Signal generation and manipulation functions
-- Phase-Amplitude Coupling (PAC) analysis
-- Modulation Index calculation
-- Hilbert transform
-- Power Spectral Density (PSD) estimation
-- Resampling utilities
-
-## Galleries
-
-
-
-
-## Quick Start
-```python
-# Parameters
-SRC_FS = 1024 # Source sampling frequency
-TGT_FS = 512 # Target sampling frequency
-FREQS_HZ = [10, 30, 100] # Frequencies in Hz for periodic signals
-LOW_HZ = 20 # Low frequency for bandpass filter
-HIGH_HZ = 50 # High frequency for bandpass filter
-SIGMA = 10 # Sigma for Gaussian filter
-SIG_TYPES = [
- "uniform",
- "gauss",
- "periodic",
- "chirp",
- "ripple",
- "meg",
- "tensorpac",
-] # Available signal types
-
-
-# Demo Signal
-xx, tt, fs = scitex.dsp.demo_sig(
- t_sec=T_SEC, fs=SRC_FS, freqs_hz=FREQS_HZ, sig_type="chirp"
-)
-# xx.shape (batch_size, n_chs, seq_len)
-# xx.shape (batch_size, n_chs, n_segments, seq_len) # when sig_type is "tensorpac" or "pac"
-
-
-# # Various data types are automatically handled:
-# xx = torch.tensor(xx).float()
-# xx = torch.tensor(xx).float().cuda()
-# xx = np.array(xx)
-# xx = pd.DataFrame(xx)
-
-# Normalization
-xx_norm = scitex.dsp.norm.z(xx)
-xx_minmax = scitex.dsp.norm.minmax(xx)
-
-# Resampling
-xx_resampled = scitex.dsp.resample(xx, fs, TGT_FS)
-
-# Noise addition
-xx_gauss = scitex.dsp.add_noise.gauss(xx)
-xx_white = scitex.dsp.add_noise.white(xx)
-xx_pink = scitex.dsp.add_noise.pink(xx)
-xx_brown = scitex.dsp.add_noise.brown(xx)
-
-# Filtering
-xx_filted_bandpass = scitex.dsp.filt.bandpass(xx, fs, low_hz=LOW_HZ, high_hz=HIGH_HZ)
-xx_filted_bandstop = scitex.dsp.filt.bandstop(xx, fs, low_hz=LOW_HZ, high_hz=HIGH_HZ)
-xx_filted_gauss = scitex.dsp.filt.gauss(xx, sigma=SIGMA)
-
-# Hilbert Transformation
-phase, amplitude = scitex.dsp.hilbert(xx) # or envelope
-
-# Wavelet Transformation
-wavelet_coef, wavelet_freqs = scitex.dsp.wavelet(xx, fs)
-
-# Power Spetrum Density
-psd, psd_freqs = scitex.dsp.psd(xx, fs)
-
-# Phase-Amplitude Coupling
-pac, freqs_pha, freqs_amp = scitex.dsp.pac(xx, fs) # This process is computationally intensive. Please monitor RAM/VRAM usage.
-```
-
-## API Reference
-- `scitex.dsp.wavelet_transform(signal, wavelet, level)`: Performs wavelet transform
-- `scitex.dsp.bandpass_filter(signal, lowcut, highcut, fs)`: Applies bandpass filter
-- `scitex.dsp.lowpass_filter(signal, cutoff, fs)`: Applies lowpass filter
-- `scitex.dsp.highpass_filter(signal, cutoff, fs)`: Applies highpass filter
-- `scitex.dsp.spectrogram(signal, fs, nperseg)`: Computes spectrogram
-- `scitex.dsp.stft(signal, fs, nperseg)`: Performs Short-Time Fourier Transform
-- `scitex.dsp.istft(stft, fs, nperseg)`: Performs Inverse Short-Time Fourier Transform
-- `scitex.dsp.chirp(t, f0, f1, method)`: Generates chirp signal
-- `scitex.dsp.hilbert(signal)`: Performs Hilbert transform
-- `scitex.dsp.phase_amplitude_coupling(signal, fs)`: Calculates Phase-Amplitude Coupling
-- `scitex.dsp.Modulation_index(signal, fs)`: Computes Modulation Index
-- `scitex.dsp.psd(signal, fs)`: Estimates Power Spectral Density
-- `scitex.dsp.resample(signal, orig_fs, new_fs)`: Resamples signal to new frequency
-- `scitex.dsp.add_noise(signal, snr)`: Adds noise to signal with specified SNR
-
-## Use Cases
-- Audio signal processing
-- Biomedical signal analysis
-- Vibration analysis
-- Communication systems
-- Radar and sonar signal processing
-- Neuroscience data analysis
-
-## Performance
-The `scitex.dsp` module leverages PyTorch's GPU acceleration capabilities, providing significant speedups for large-scale signal processing tasks when run on CUDA-enabled devices.
-
-## Contributing
-Contributions to improve `scitex.dsp` are welcome. Please submit pull requests or open issues on the GitHub repository.
-
-## License
-This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.
-
-## Contact
-Yusuke Watanabe (ywatanabe@scitex.ai)
-
-For more information and updates, please visit the [scitex GitHub repository](https://github.com/ywatanabe1989/scitex).
diff --git a/src/scitex/dsp/__init__.py b/src/scitex/dsp/__init__.py
index 2e6f66f6f..e8e37f736 100755
--- a/src/scitex/dsp/__init__.py
+++ b/src/scitex/dsp/__init__.py
@@ -1,86 +1,13 @@
-#!/usr/bin/env python3
-"""Scitex dsp module."""
+"""SciTeX dsp — thin compatibility shim for scitex-dsp."""
-import warnings
+import sys as _sys
-# Import example, params, norm, reference, filt, and add_noise modules as submodules
-from . import add_noise, example, filt, norm, params, reference
-
-# Core imports that should always work
-from ._crop import crop
-from ._demo_sig import demo_sig
-from ._detect_ripples import (
- _calc_relative_peak_position,
- _drop_ripples_at_edges,
- _find_events,
- _preprocess,
- _sort_columns,
- detect_ripples,
-)
-from ._ensure_3d import ensure_3d
-from ._hilbert import hilbert
-from ._modulation_index import _reshape, modulation_index
-from ._pac import pac
-from ._psd import band_powers, psd
-from ._resample import resample
-from ._time import time
-from ._transform import to_segments, to_sktime_df
-from ._wavelet import wavelet
-
-# Try to import audio-related functions that require PortAudio
try:
- from ._listen import list_and_select_device
-
- _audio_available = True
-except (ImportError, OSError):
- warnings.warn(
- "Audio functionality unavailable: PortAudio library not found. "
- "Install PortAudio to use audio features (e.g., sudo apt-get install portaudio19-dev)",
- ImportWarning,
- )
- list_and_select_device = None
- _audio_available = False
-
-# Try to import MNE-related functions
-try:
- from ._mne import get_eeg_pos
-
- _mne_available = True
-except ImportError:
- warnings.warn(
- "MNE functionality unavailable. Install MNE-Python to use EEG position features.",
- ImportWarning,
- )
- get_eeg_pos = None
- _mne_available = False
-
-__all__ = [
- "_calc_relative_peak_position",
- "_drop_ripples_at_edges",
- "_find_events",
- "_preprocess",
- "_reshape",
- "_sort_columns",
- "add_noise",
- "band_powers",
- "crop",
- "demo_sig",
- "detect_ripples",
- "ensure_3d",
- "example",
- "filt",
- "get_eeg_pos",
- "hilbert",
- "list_and_select_device",
- "modulation_index",
- "norm",
- "pac",
- "params",
- "psd",
- "reference",
- "resample",
- "time",
- "to_segments",
- "to_sktime_df",
- "wavelet",
-]
+ import scitex_dsp as _real
+except ImportError as _e:
+ raise ImportError(
+ "scitex.dsp requires the 'scitex-dsp' package. "
+ "Install with: pip install scitex[dsp] (or: pip install scitex-dsp)"
+ ) from _e
+
+_sys.modules[__name__] = _real
diff --git a/src/scitex/dsp/_crop.py b/src/scitex/dsp/_crop.py
deleted file mode 100755
index a651387cf..000000000
--- a/src/scitex/dsp/_crop.py
+++ /dev/null
@@ -1,125 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "ywatanabe (2024-11-02 22:50:46)"
-# File: ./scitex_repo/src/scitex/dsp/_crop.py
-
-import numpy as np
-
-
-def crop(sig_2d, window_length, overlap_factor=0.0, axis=-1, time=None):
- """
- Crops the input signal into overlapping windows of a specified length,
- allowing for an arbitrary axis and considering a time vector.
-
- Parameters:
- - sig_2d (numpy.ndarray): The input sig_2d array to be cropped. Can be multi-dimensional.
- - window_length (int): The length of each window to crop the sig_2d into.
- - overlap_factor (float): The fraction of the window that consecutive windows overlap. For example, an overlap_factor of 0.5 means 50% overlap.
- - axis (int): The time axis along which to crop the sig_2d.
- - time (numpy.ndarray): The time vector associated with the signal. Its length should match the signal's length along the cropping axis.
-
- Returns:
- - cropped_windows (numpy.ndarray): The cropped signal windows. The shape depends on the input shape and the specified axis.
- """
- # Ensure axis is in a valid range
- if axis < 0:
- axis += sig_2d.ndim
- if axis >= sig_2d.ndim or axis < 0:
- raise ValueError("Invalid axis. Axis out of range for sig_2d dimensions.")
-
- if time is not None:
- # Validate the length of the time vector against the signal's dimension
- if sig_2d.shape[axis] != len(time):
- raise ValueError(
- "Length of time vector does not match signal's dimension along the specified axis."
- )
-
- # Move the target axis to the last position
- axes = np.arange(sig_2d.ndim)
- axes[axis], axes[-1] = axes[-1], axes[axis]
- sig_2d_permuted = np.transpose(sig_2d, axes)
-
- # Compute the number of windows and the step size
- seq_len = sig_2d_permuted.shape[-1]
- step = int(window_length * (1 - overlap_factor))
- n_windows = max(
- 1, ((seq_len - window_length) // step + 1)
- ) # Ensure at least 1 window
-
- # Crop the sig_2d into windows
- cropped_windows = []
- cropped_times = []
- for i in range(n_windows):
- start = i * step
- end = start + window_length
- cropped_windows.append(sig_2d_permuted[..., start:end])
- if time is not None:
- cropped_times.append(time[start:end])
-
- # Convert list of windows back to numpy array
- cropped_windows = np.array(cropped_windows)
- cropped_times = np.array(cropped_times)
-
- # Move the last axis back to its original position if necessary
- if axis != sig_2d.ndim - 1:
- # Compute the inverse permutation
- inv_axes = np.argsort(axes)
- cropped_windows = np.transpose(cropped_windows, axes=inv_axes)
-
- if time is None:
- return cropped_windows
- else:
- return cropped_windows, cropped_times
-
-
-def main():
- import random
-
- FS = 128
- N_CHS = 19
- RECORD_S = 13
- WINDOW_S = 2
- FACTOR = 0.5
-
- # To pts
- record_pts = int(RECORD_S * FS)
- window_pts = int(WINDOW_S * FS)
-
- # Demo signal
- sig2d = np.random.rand(N_CHS, record_pts)
- time = np.arange(record_pts) / FS
-
- # Main
- xx, tt = crop(sig2d, window_pts, overlap_factor=FACTOR, time=time)
-
- print(f"sig2d.shape: {sig2d.shape}")
- print(f"xx.shape: {xx.shape}")
-
- # Validation
- i_seg = random.randint(0, len(xx) - 1)
- start = int(i_seg * window_pts * FACTOR)
- end = start + window_pts
- assert np.allclose(sig2d[:, start:end], xx[i_seg])
-
-
-if __name__ == "__main__":
- # parser = argparse.ArgumentParser(description='')
- # import argparse
- # # Argument Parser
- import sys
-
- import matplotlib.pyplot as plt
-
- import scitex
-
- # parser.add_argument('--var', '-v', type=int, default=1, help='')
- # parser.add_argument('--flag', '-f', action='store_true', default=False, help='')
- # args = parser.parse_args()
- # Main
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(
- sys, plt, verbose=False
- )
- main()
- scitex.session.close(CONFIG, verbose=False, notify=False)
-
-# EOF
diff --git a/src/scitex/dsp/_demo_sig.py b/src/scitex/dsp/_demo_sig.py
deleted file mode 100755
index 348d6ddc3..000000000
--- a/src/scitex/dsp/_demo_sig.py
+++ /dev/null
@@ -1,380 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2024-11-06 01:45:32 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/_demo_sig.py
-
-import random
-import sys
-import warnings
-
-import matplotlib.pyplot as plt
-import numpy as np
-from scipy.signal import chirp
-
-try:
- import mne
- from mne.datasets import sample
-
- MNE_AVAILABLE = True
-except ImportError:
- MNE_AVAILABLE = False
- mne = None
- sample = None
-
-try:
- from ripple_detection.simulate import simulate_LFP, simulate_time
-
- RIPPLE_DETECTION_AVAILABLE = True
-except ImportError:
- RIPPLE_DETECTION_AVAILABLE = False
- simulate_LFP = None
- simulate_time = None
-
-try:
- from tensorpac.signals import pac_signals_wavelet
-
- TENSORPAC_AVAILABLE = True
-except ImportError:
- TENSORPAC_AVAILABLE = False
- pac_signals_wavelet = None
-
-from scitex.io import load_configs
-
-# Config
-CONFIG = load_configs(verbose=False)
-
-
-def _check_mne():
- if not MNE_AVAILABLE:
- raise ImportError(
- "MNE-Python is not installed. Please install with: pip install mne"
- )
-
-
-def _check_ripple_detection():
- if not RIPPLE_DETECTION_AVAILABLE:
- raise ImportError(
- "ripple_detection is not installed. Please install with: pip install ripple_detection"
- )
-
-
-def _check_tensorpac():
- if not TENSORPAC_AVAILABLE:
- raise ImportError(
- "tensorpac is not installed. Please install with: pip install tensorpac"
- )
-
-
-# Functions
-def demo_sig(
- sig_type="periodic",
- batch_size=8,
- n_chs=19,
- n_segments=20,
- t_sec=4,
- fs=512,
- freqs_hz=None,
- verbose=False,
-):
- """
- Generate demo signals for various signal types.
-
- Parameters
- ----------
- sig_type : str, optional
- Type of signal to generate. Options are "uniform", "gauss", "periodic", "chirp", "ripple", "meg", "tensorpac", "pac".
- Default is "periodic".
- batch_size : int, optional
- Number of batches to generate. Default is 8.
- n_chs : int, optional
- Number of channels. Default is 19.
- n_segments : int, optional
- Number of segments for tensorpac and pac signals. Default is 20.
- t_sec : float, optional
- Duration of the signal in seconds. Default is 4.
- fs : int, optional
- Sampling frequency in Hz. Default is 512.
- freqs_hz : list or None, optional
- List of frequencies in Hz for periodic signals. If None, random frequencies will be used.
- verbose : bool, optional
- If True, print additional information. Default is False.
-
- Returns
- -------
- tuple
- A tuple containing:
- - np.ndarray: Generated signal(s) with shape (batch_size, n_chs, time_samples) or (batch_size, n_chs, n_segments, time_samples) for tensorpac and pac signals.
- - np.ndarray: Time array.
- - int: Sampling frequency.
- """
- assert sig_type in [
- "uniform",
- "gauss",
- "periodic",
- "chirp",
- "ripple",
- "meg",
- "tensorpac",
- "pac",
- ]
- tt = np.linspace(0, t_sec, int(t_sec * fs), endpoint=False)
-
- if sig_type == "uniform":
- return (
- np.random.uniform(low=-0.5, high=0.5, size=(batch_size, n_chs, len(tt))),
- tt,
- fs,
- )
-
- elif sig_type == "gauss":
- return np.random.randn(batch_size, n_chs, len(tt)), tt, fs
-
- elif sig_type == "meg":
- return (
- _demo_sig_meg(
- batch_size=batch_size,
- n_chs=n_chs,
- t_sec=t_sec,
- fs=fs,
- verbose=verbose,
- ).astype(np.float32)[..., : len(tt)],
- tt,
- fs,
- )
-
- elif sig_type == "tensorpac":
- xx, tt = _demo_sig_tensorpac(
- batch_size=batch_size,
- n_chs=n_chs,
- n_segments=n_segments,
- t_sec=t_sec,
- fs=fs,
- )
- return xx.astype(np.float32)[..., : len(tt)], tt, fs
-
- elif sig_type == "pac":
- xx = _demo_sig_pac(
- batch_size=batch_size,
- n_chs=n_chs,
- n_segments=n_segments,
- t_sec=t_sec,
- fs=fs,
- )
- return xx.astype(np.float32)[..., : len(tt)], tt, fs
-
- else:
- fn_1d = {
- "periodic": _demo_sig_periodic_1d,
- "chirp": _demo_sig_chirp_1d,
- "ripple": _demo_sig_ripple_1d,
- }.get(sig_type)
-
- return (
- (
- np.array(
- [
- fn_1d(
- t_sec=t_sec,
- fs=fs,
- freqs_hz=freqs_hz,
- verbose=verbose,
- )
- for _ in range(int(batch_size * n_chs))
- ]
- )
- .reshape(batch_size, n_chs, -1)
- .astype(np.float32)[..., : len(tt)]
- ),
- tt,
- fs,
- )
-
-
-def _demo_sig_pac(
- batch_size=8,
- n_chs=19,
- t_sec=4,
- fs=512,
- f_pha=10,
- f_amp=100,
- noise=0.8,
- n_segments=20,
- verbose=False,
-):
- """
- Generate a demo signal with phase-amplitude coupling.
-
- Parameters
- ----------
- batch_size (int): Number of batches.
- n_chs (int): Number of channels.
- t_sec (int): Duration of the signal in seconds.
- fs (int): Sampling frequency.
- f_pha (float): Frequency of the phase-modulating signal.
- f_amp (float): Frequency of the amplitude-modulated signal.
- noise (float): Noise level added to the signal.
- n_segments (int): Number of segments.
- verbose (bool): If True, print additional information.
-
- Returns
- -------
- np.array: Generated signals with shape (batch_size, n_chs, n_segments, seq_len).
- """
- seq_len = t_sec * fs
- t = np.arange(seq_len) / fs
- if verbose:
- print(f"Generating signal with length: {seq_len}")
-
- # Create empty array to store the signals
- signals = np.zeros((batch_size, n_chs, n_segments, seq_len))
-
- for b in range(batch_size):
- for ch in range(n_chs):
- for seg in range(n_segments):
- # Phase signal
- theta = np.sin(2 * np.pi * f_pha * t)
- # Amplitude envelope
- amplitude_env = 1 + np.sin(2 * np.pi * f_amp * t)
- # Combine phase and amplitude modulation
- signal = theta * amplitude_env
- # Add Gaussian noise
- signal += noise * np.random.randn(seq_len)
- signals[b, ch, seg, :] = signal
-
- return signals
-
-
-def _demo_sig_tensorpac(
- batch_size=8,
- n_chs=19,
- t_sec=4,
- fs=512,
- f_pha=10,
- f_amp=100,
- noise=0.8,
- n_segments=20,
- verbose=False,
-):
- _check_tensorpac()
- n_times = int(t_sec * fs)
- x_2d, tt = pac_signals_wavelet(
- sf=fs,
- f_pha=f_pha,
- f_amp=f_amp,
- noise=noise,
- n_epochs=n_segments,
- n_times=n_times,
- )
- x_3d = np.stack([x_2d for _ in range(batch_size)], axis=0)
- x_4d = np.stack([x_3d for _ in range(n_chs)], axis=1)
- return x_4d, tt
-
-
-def _demo_sig_meg(batch_size=8, n_chs=19, t_sec=10, fs=512, verbose=False, **kwargs):
- _check_mne()
- data_path = sample.data_path()
- meg_path = data_path / "MEG" / "sample"
- raw_fname = meg_path / "sample_audvis_raw.fif"
- fwd_fname = meg_path / "sample_audvis-meg-eeg-oct-6-fwd.fif"
-
- # Load real data as the template
- raw = mne.io.read_raw_fif(raw_fname, verbose=verbose)
- raw = raw.crop(tmax=t_sec, verbose=verbose)
- raw = raw.resample(fs, verbose=verbose)
- raw.set_eeg_reference(projection=True, verbose=verbose)
-
- return raw.get_data(
- picks=raw.ch_names[: batch_size * n_chs], verbose=verbose
- ).reshape(batch_size, n_chs, -1)
-
-
-def _demo_sig_periodic_1d(t_sec=10, fs=512, freqs_hz=None, verbose=False, **kwargs):
- """Returns a demo signal with the shape (t_sec*fs,)."""
- if freqs_hz is None:
- n_freqs = random.randint(1, 5)
- freqs_hz = np.random.permutation(np.arange(fs))[:n_freqs]
- if verbose:
- print(f"freqs_hz was randomly determined as {freqs_hz}")
-
- n = int(t_sec * fs)
- t = np.linspace(0, t_sec, n, endpoint=False)
-
- summed = np.array(
- [
- np.random.rand() * np.sin((f_hz * t + np.random.rand()) * (2 * np.pi))
- for f_hz in freqs_hz
- ]
- ).sum(axis=0)
- return summed
-
-
-def _demo_sig_chirp_1d(
- t_sec=10, fs=512, low_hz=None, high_hz=None, verbose=False, **kwargs
-):
- if low_hz is None:
- low_hz = random.randint(1, 20)
- if verbose:
- warnings.warn(f"low_hz was randomly determined as {low_hz}.")
-
- if high_hz is None:
- high_hz = random.randint(100, 1000)
- if verbose:
- warnings.warn(f"high_hz was randomly determined as {high_hz}.")
-
- n = int(t_sec * fs)
- t = np.linspace(0, t_sec, n, endpoint=False)
- x = chirp(t, low_hz, t[-1], high_hz)
- x *= 1.0 + 0.5 * np.sin(2.0 * np.pi * 3.0 * t)
- return x
-
-
-def _demo_sig_ripple_1d(t_sec=10, fs=512, **kwargs):
- _check_ripple_detection()
- n_samples = t_sec * fs
- t = simulate_time(n_samples, fs)
- n_ripples = random.randint(1, 5)
- mid_time = np.random.permutation(t)[:n_ripples]
- return simulate_LFP(t, mid_time, noise_amplitude=1.2, ripple_amplitude=5)
-
-
-if __name__ == "__main__":
- import scitex
-
- # Start
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt)
- import scitex
-
- SIG_TYPES = [
- "uniform",
- "gauss",
- "periodic",
- "chirp",
- "meg",
- "ripple",
- "tensorpac",
- "pac",
- ]
-
- i_batch, i_ch, i_segment = 0, 0, 0
- fig, axes = scitex.plt.subplots(nrows=len(SIG_TYPES))
- for ax, (i_sig_type, sig_type) in zip(axes, enumerate(SIG_TYPES)):
- xx, tt, fs = demo_sig(sig_type=sig_type)
- if sig_type not in ["tensorpac", "pac"]:
- ax.plot(tt, xx[i_batch, i_ch], label=sig_type)
- else:
- ax.plot(tt, xx[i_batch, i_ch, i_segment], label=sig_type)
- ax.legend(loc="upper left")
- fig.suptitle("Demo signals")
- fig.supxlabel("Time [s]")
- fig.supylabel("Amplitude [?V]")
- scitex.io.save(fig, "traces.png")
-
- # Close
- scitex.session.close(CONFIG)
-
-# EOF
-
-"""
-/home/ywatanabe/proj/entrance/scitex/dsp/_demo_sig.py
-"""
-
-# EOF
diff --git a/src/scitex/dsp/_detect_ripples.py b/src/scitex/dsp/_detect_ripples.py
deleted file mode 100755
index 0e1a3bed5..000000000
--- a/src/scitex/dsp/_detect_ripples.py
+++ /dev/null
@@ -1,216 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-11-05 00:24:54 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/_detect_ripples.py
-
-import numpy as np
-import pandas as pd
-from scipy.signal import find_peaks
-
-from scitex.gen._norm import to_z
-
-from ._demo_sig import demo_sig
-from ._hilbert import hilbert
-from ._resample import resample
-from .filt import bandpass, gauss
-
-
-def detect_ripples(
- xx,
- fs,
- low_hz,
- high_hz,
- sd=2.0,
- smoothing_sigma_ms=4,
- min_duration_ms=10,
- return_preprocessed_signal=False,
-):
- """
- xx: 2-dimensional (n_chs, seq_len) or 3-dimensional (batch_size, n_chs, seq_len) wide-band signal.
- """
-
- try:
- xx_r, fs_r = _preprocess(xx, fs, low_hz, high_hz, smoothing_sigma_ms)
- df = _find_events(xx_r, fs_r, sd, min_duration_ms)
- df = _drop_ripples_at_edges(df, low_hz, xx_r, fs_r)
- df = _calc_relative_peak_position(df)
- # df = _calc_incidence(df, xx_r, fs_r)
- df = _sort_columns(df)
-
- if not return_preprocessed_signal:
- return df
-
- elif return_preprocessed_signal:
- return df, xx_r, fs_r
-
- except ValueError as e:
- print("Caught an error:", e)
-
-
-def _preprocess(xx, fs, low_hz, high_hz, smoothing_sigma_ms=4):
- # Ensures three dimensional
- if xx.ndim == 2:
- xx = xx[np.newaxis]
- assert xx.ndim == 3
-
- # For readability
- RIPPLE_BANDS = np.vstack([[low_hz, high_hz]])
-
- # Downsampling
- fs_tgt = low_hz * 3
- xx = resample(xx, float(fs), float(fs_tgt))
- fs = fs_tgt
-
- # Subtracts the global mean to reduce false detection due to EMG signal
- xx -= np.nanmean(xx, axis=1, keepdims=True)
-
- # Bandpass Filtering
- xx = (
- (
- bandpass(
- np.array(xx),
- fs_tgt,
- RIPPLE_BANDS,
- )
- )
- .squeeze(-2)
- .astype(np.float64)
- )
-
- # Calculate RMS
- xx = xx**2
- _, xx = hilbert(xx)
- xx = gauss(xx, smoothing_sigma_ms * 1e-3 * fs_tgt).squeeze(-2)
- xx = np.sqrt(xx)
-
- # Scales across channels
- xx = xx.mean(axis=1)
- xx = to_z(xx, dim=-1)
-
- return xx, fs_tgt
-
-
-def _find_events(xx_r, fs_r, sd, min_duration_ms):
- def _find_events_1d(xx_ri, fs_r, sd, min_duration_ms):
- # Finds peaks over the designated standard deviation
- peaks, properties = find_peaks(xx_ri, height=sd)
-
- # Determines the range around each peak (customize as needed)
- peaks_all = []
- peak_ranges = []
- peak_amplitudes_sd = []
-
- for peak in peaks:
- left_bound = np.where(xx_ri[:peak] < 0)[0]
- right_bound = np.where(xx_ri[peak:] < 0)[0]
-
- left_ips = left_bound.max() if left_bound.size > 0 else peak
- right_ips = peak + right_bound.min() if right_bound.size > 0 else peak
-
- # Avoid duplicates: Check if the current peak range is already listed
- if not any(
- (left_ips == start and right_ips == end) for start, end in peak_ranges
- ):
- peaks_all.append(peak)
- peak_ranges.append((left_ips, right_ips))
- peak_amplitudes_sd.append(xx_ri[peak])
-
- # Converts to DataFrame
- if peak_ranges:
- starts, ends = zip(*peak_ranges) if peak_ranges else ([], [])
- df = pd.DataFrame(
- {
- "start_s": np.hstack(starts) / fs_r,
- "peak_s": np.hstack(peaks_all) / fs_r,
- "end_s": np.hstack(ends) / fs_r,
- "peak_amp_sd": np.hstack(peak_amplitudes_sd),
- }
- ).round(3)
- else:
- df = pd.DataFrame(columns=["start_s", "peak_s", "end_s", "peak_amp_sd"])
-
- # Duration
- df["duration_s"] = df.end_s - df.start_s
-
- # Filters events with short duration
- df = df[df.duration_s > (min_duration_ms * 1e-3)]
-
- return df
-
- if xx_r.ndim == 1:
- xx_r = xx_r[np.newaxis, :]
- assert xx_r.ndim == 2
-
- dfs = []
- for i_ch in range(len(xx_r)):
- xx_ri = xx_r[i_ch]
- df_i = _find_events_1d(xx_ri, fs_r, sd, min_duration_ms)
- df_i.index = [i_ch for _ in range(len(df_i))]
- dfs.append(df_i)
- dfs = pd.concat(dfs)
-
- return dfs
-
-
-def _drop_ripples_at_edges(df, low_hz, xx_r, fs_r):
- edge_s = 1 / low_hz * 3
- indi_drop = (df.start_s < edge_s) + (xx_r.shape[-1] / fs_r - edge_s < df.end_s)
- df = df[~indi_drop]
- return df
-
-
-def _calc_relative_peak_position(df):
- delta_s = df.peak_s - df.start_s
- rel_peak = delta_s / df.duration_s
- df["rel_peak_pos"] = np.round(rel_peak, 3)
- return df
-
-
-# def _calc_incidence(df, xx_r, fs_r):
-# n_ripples = len(df)
-# rec_s = xx_r.shape[-1] / fs_r
-# df["incidence_hz"] = n_ripples / rec_s
-# return df
-
-
-def _sort_columns(df):
- sorted_columns = [
- "start_s",
- "end_s",
- "duration_s",
- "peak_s",
- "rel_peak_pos",
- "peak_amp_sd",
- # "incidence_hz",
- ]
- df = df[sorted_columns]
- return df
-
-
-def main():
- xx, tt, fs = demo_sig(sig_type="ripple")
- df = detect_ripples(xx, fs, 80, 140)
- print(df)
-
-
-if __name__ == "__main__":
- import sys
-
- import matplotlib.pyplot as plt
-
- import scitex
-
- # # Argument Parser
- # import argparse
- # parser = argparse.ArgumentParser(description='')
- # parser.add_argument('--var', '-v', type=int, default=1, help='')
- # parser.add_argument('--flag', '-f', action='store_true', default=False, help='')
- # args = parser.parse_args()
- # Main
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(
- sys, plt, verbose=False
- )
- main()
- scitex.session.close(CONFIG, verbose=False, notify=False)
-
-# EOF
diff --git a/src/scitex/dsp/_ensure_3d.py b/src/scitex/dsp/_ensure_3d.py
deleted file mode 100755
index 3fba15efd..000000000
--- a/src/scitex/dsp/_ensure_3d.py
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-11-05 01:03:47 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/_ensure_3d.py
-
-from scitex.decorators import signal_fn
-
-
-@signal_fn
-def ensure_3d(x):
- if x.ndim == 1: # assumes (seq_len,)
- x = x.unsqueeze(0).unsqueeze(0)
- elif x.ndim == 2: # assumes (batch_siize, seq_len)
- x = x.unsqueeze(1)
- return x
-
-
-# EOF
diff --git a/src/scitex/dsp/_hilbert.py b/src/scitex/dsp/_hilbert.py
deleted file mode 100755
index 181083411..000000000
--- a/src/scitex/dsp/_hilbert.py
+++ /dev/null
@@ -1,78 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-11-04 02:07:11 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/_hilbert.py
-
-"""
-This script does XYZ.
-"""
-
-import sys
-
-import matplotlib.pyplot as plt
-
-from scitex.decorators import signal_fn
-from scitex.nn._Hilbert import Hilbert
-
-
-# Functions
-@signal_fn
-def hilbert(
- x,
- dim=-1,
-):
- y = Hilbert(x.shape[-1], dim=dim)(x)
- return y[..., 0], y[..., 1]
-
-
-if __name__ == "__main__":
- import scitex
-
- # Start
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt)
-
- # Parameters
- T_SEC = 1.0
- FS = 400
- SIG_TYPE = "chirp"
-
- # Demo signal
- xx, tt, fs = scitex.dsp.demo_sig(t_sec=T_SEC, fs=FS, sig_type=SIG_TYPE)
-
- # Main
- pha, amp = hilbert(
- xx,
- dim=-1,
- )
- # (32, 19, 1280, 2)
-
- # Plots
- fig, axes = scitex.plt.subplots(nrows=2, sharex=True)
- fig.suptitle("Hilbert Transformation")
-
- axes[0].plot(tt, xx[0, 0], label=SIG_TYPE)
- axes[0].plot(tt, amp[0, 0], label="Amplidue")
- axes[0].legend()
- # axes[0].set_xlabel("Time [s]")
- axes[0].set_ylabel("Amplitude [?V]")
-
- axes[1].plot(tt, pha[0, 0], label="Phase")
- axes[1].legend()
-
- axes[1].set_xlabel("Time [s]")
- axes[1].set_ylabel("Phase [rad]")
-
- # plt.show()
- scitex.io.save(fig, "traces.png")
-
- # Close
- scitex.session.close(CONFIG)
-
-# EOF
-
-"""
-/home/ywatanabe/proj/entrance/scitex/dsp/_hilbert.py
-"""
-
-
-# EOF
diff --git a/src/scitex/dsp/_listen.py b/src/scitex/dsp/_listen.py
deleted file mode 100755
index b7bf9bd50..000000000
--- a/src/scitex/dsp/_listen.py
+++ /dev/null
@@ -1,702 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-11-08 09:15:59 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/_listen.py
-
-import os
-import sys
-
-import matplotlib.pyplot as plt
-import sounddevice as sd
-
-os.environ["PULSE_SERVER"] = "unix:/mnt/wslg/PulseServer"
-
-# # WSL2 Sound Support
-# export PULSE_SERVER=unix:/mnt/wslg/PulseServer
-
-
-def list_and_select_device() -> int:
- """
- List available audio devices and prompt user to select one.
-
- Example
- -------
- >>> device_id = list_and_select_device()
- Available audio devices:
- ...
- Enter the ID of the device you want to use:
-
- Returns
- -------
- int
- Selected device ID
- """
- try:
- print("Available audio devices:")
- devices = sd.query_devices()
- print(devices)
- device_id = int(input("Enter the ID of the device you want to use: "))
- if device_id not in range(len(devices)):
- raise ValueError(f"Invalid device ID: {device_id}")
- return device_id
- except (ValueError, sd.PortAudioError) as err:
- print(f"Error during device selection: {err}")
- return 0
-
-
-if __name__ == "__main__":
- import scitex
-
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt)
-
- signal, time_points, sampling_freq = scitex.dsp.demo_sig("chirp")
-
- device_id = list_and_select_device()
- sd.default.device = device_id
-
- listen(signal, sampling_freq)
-
- scitex.session.close(CONFIG)
-
-# def play_audio(
-# samples: np.ndarray, fs: int = 44100, channels: int = 1
-# ) -> None:
-# """Play audio using PyAudio"""
-# print("Initializing PyAudio...")
-# p = pyaudio.PyAudio()
-
-# # List available devices
-# print("\nAvailable audio devices:")
-# for i in range(p.get_device_count()):
-# dev = p.get_device_info_by_index(i)
-# print(f"Device {i}: {dev['name']}")
-
-# try:
-# # Rest of the code remains the same
-# if samples.dtype != np.float32:
-# print(f"Converting from {samples.dtype} to float32...")
-# samples = samples.astype(np.float32)
-
-# if len(samples) == 0:
-# print("No input samples, creating test tone...")
-# duration = 1
-# t = np.linspace(0, duration, int(fs * duration))
-# samples = np.sin(2 * np.pi * 440 * t)
-
-# print(f"Opening audio stream (fs={fs}Hz, channels={channels})...")
-# stream = p.open(
-# format=pyaudio.paFloat32, channels=channels, rate=fs, output=True
-# )
-
-# print("Playing audio...")
-# stream.write(samples.tobytes())
-
-# except Exception as e:
-# print(f"Error: {e}")
-# finally:
-# print("Cleaning up...")
-# if "stream" in locals():
-# stream.stop_stream()
-# stream.close()
-# p.terminate()
-# print("Done.")
-
-# if __name__ == "__main__":
-# # Test with a simple sine wave
-# duration = 2
-# fs = 44100
-# t = np.linspace(0, duration, int(fs * duration))
-# test_signal = np.sin(2 * np.pi * 440 * t) # 440 Hz
-# play_audio(test_signal, fs)
-
-# # def record_audio(
-# # duration: float = 5.0, fs: int = 44100, channels: int = 1
-# # ) -> np.ndarray:
-# # """Record audio using PyAudio with WSL2 compatibility"""
-# # chunk = 1024
-# # format = pyaudio.paFloat32
-
-# # p = pyaudio.PyAudio()
-
-# # # List available devices
-# # for i in range(p.get_device_count()):
-# # print(p.get_device_info_by_index(i))
-
-# # try:
-# # stream = p.open(
-# # format=format,
-# # channels=channels,
-# # rate=fs,
-# # input=True,
-# # frames_per_buffer=chunk,
-# # input_device_index=None, # Use default device
-# # )
-
-# # frames = []
-# # print("Recording...")
-
-# # for _ in range(0, int(fs / chunk * duration)):
-# # data = stream.read(chunk)
-# # frames.append(data)
-
-# # print("Done recording")
-# # return np.frombuffer(b"".join(frames), dtype=np.float32)
-
-# # finally:
-# # if "stream" in locals():
-# # stream.stop_stream()
-# # stream.close()
-# # p.terminate()
-
-# # def record_audio(
-# # duration: float = 5.0, fs: int = 44100, channels: int = 1
-# # ) -> np.ndarray:
-# # """
-# # Record audio using PyAudio.
-# # """
-# # chunk = 1024
-# # format = pyaudio.paFloat32
-
-# # p = pyaudio.PyAudio()
-
-# # stream = p.open(
-# # format=format,
-# # channels=channels,
-# # rate=fs,
-# # input=True,
-# # frames_per_buffer=chunk,
-# # )
-
-# # frames = []
-
-# # for _ in range(0, int(fs / chunk * duration)):
-# # data = stream.read(chunk)
-# # frames.append(data)
-
-# # stream.stop_stream()
-# # stream.close()
-# # p.terminate()
-
-# # return np.frombuffer(b"".join(frames), dtype=np.float32)
-
-# # if __name__ == "__main__":
-# # # Test recording
-# # audio = record_audio(duration=3.0)
-# # print(f"Recorded {len(audio)} samples")
-
-# # # #!/usr/bin/env python3
-# # # # -*- coding: utf-8 -*-
-# # # # Time-stamp: "2024-11-08 08:56:12 (ywatanabe)"
-# # # # File: ./scitex_repo/src/scitex/dsp/_listen.py
-
-# # # import os
-# # # import sys
-# # # from typing import Any, Dict, Literal, Tuple
-
-# # # import matplotlib.pyplot as plt
-# # # import scitex
-# # # import numpy as np
-# # # from scipy.signal import resample
-
-# # # os.environ["DISPLAY"] = ":0" # Set a default display
-
-# # # # Avoid GUI/clipboard dependencies
-# # # def dummy_clipboard_get():
-# # # raise NotImplementedError(
-# # # "Clipboard not available in headless environment"
-# # # )
-
-# # # try:
-# # # from IPython import get_ipython
-
-# # # ipython = get_ipython()
-# # # if ipython is not None:
-# # # ipython.hooks.clipboard_get = dummy_clipboard_get
-# # # except ImportError:
-# # # pass
-
-# # # """
-# # # Functionality:
-# # # - Provides audio playback and signal sonification functionality
-# # # - Supports multiple sonification methods (frequency shift, AM/FM Modulation)
-# # # - Includes device selection and audio information display utilities
-# # # Input:
-# # # - Multichannel signal arrays (numpy.ndarray)
-# # # - Sampling frequency and sonification parameters
-# # # Output:
-# # # - Audio playback through specified output device
-# # # Prerequisites:
-# # # - PortAudio library (install with: sudo apt-get install portaudio19-dev)
-# # # - sounddevice package
-# # # """
-
-# # # """Imports"""
-# # # """Config"""
-# # # CONFIG = scitex.gen.load_configs()
-
-# # # """Functions"""
-
-# # # def frequency_shift(
-# # # signal: np.ndarray,
-# # # shift_factor: int = 200,
-# # # ) -> np.ndarray:
-# # # """
-# # # Shift signal frequencies by resampling.
-
-# # # Example
-# # # -------
-# # # >>> signal = np.sin(2 * np.pi * 10 * np.linspace(0, 1, 1000))
-# # # >>> shifted = frequency_shift(signal, shift_factor=200)
-
-# # # Parameters
-# # # ----------
-# # # signal : np.ndarray
-# # # Input signal to be frequency shifted
-# # # shift_factor : int
-# # # Frequency multiplication factor
-
-# # # Returns
-# # # -------
-# # # np.ndarray
-# # # Frequency shifted signal
-# # # """
-# # # num_samples = int(len(signal) * shift_factor)
-# # # return resample(signal, num_samples)
-
-# # # def am_Modulation(
-# # # signal: np.ndarray, carrier_freq: float = 440, fs: int = 44_100
-# # # ) -> np.ndarray:
-# # # """
-# # # Perform amplitude Modulation on input signal.
-
-# # # Example
-# # # -------
-# # # >>> signal = np.sin(2 * np.pi * 10 * np.linspace(0, 1, 1000))
-# # # >>> modulated = am_Modulation(signal, carrier_freq=440, fs=44100)
-
-# # # Parameters
-# # # ----------
-# # # signal : np.ndarray
-# # # Input signal to modulate
-# # # carrier_freq : float
-# # # Carrier frequency in Hz
-# # # fs : int
-# # # Sampling frequency in Hz
-
-# # # Returns
-# # # -------
-# # # np.ndarray
-# # # Amplitude modulated signal
-# # # """
-# # # t = np.arange(len(signal)) / fs
-# # # carrier = np.sin(2 * np.pi * carrier_freq * t)
-# # # return (1 + signal) * carrier
-
-# # # def fm_Modulation(
-# # # signal: np.ndarray,
-# # # carrier_freq: float = 440,
-# # # sensitivity: float = 0.5,
-# # # fs: int = 44_100,
-# # # ) -> np.ndarray:
-# # # """
-# # # Perform frequency Modulation on input signal.
-
-# # # Example
-# # # -------
-# # # >>> signal = np.sin(2 * np.pi * 10 * np.linspace(0, 1, 1000))
-# # # >>> modulated = fm_Modulation(signal, carrier_freq=440, sensitivity=0.5, fs=44100)
-
-# # # Parameters
-# # # ----------
-# # # signal : np.ndarray
-# # # Input signal to modulate
-# # # carrier_freq : float
-# # # Carrier frequency in Hz
-# # # sensitivity : float
-# # # Frequency sensitivity factor
-# # # fs : int
-# # # Sampling frequency in Hz
-
-# # # Returns
-# # # -------
-# # # np.ndarray
-# # # Frequency modulated signal
-# # # """
-# # # t = np.arange(len(signal)) / fs
-# # # phase = 2 * np.pi * carrier_freq * t + sensitivity * np.cumsum(signal) / fs
-# # # return np.sin(phase)
-
-# # # def sonify_eeg(
-# # # signal_array: np.ndarray,
-# # # sampling_freq: int,
-# # # method: Literal["shift", "am", "fm"] = "shift",
-# # # channels: Tuple[int, ...] = (0, 1),
-# # # target_fs: int = 44_100,
-# # # **kwargs: Dict[str, Any],
-# # # ) -> None:
-# # # """
-# # # Convert EEG signal to audio using various sonification methods.
-
-# # # Example
-# # # -------
-# # # >>> eeg_data = np.random.randn(1, 32, 1000) # Mock EEG data
-# # # >>> sonify_eeg(eeg_data, 250, method='fm', channels=(0,1))
-
-# # # Parameters
-# # # ----------
-# # # signal_array : np.ndarray
-# # # EEG signal array of shape (batch_size, n_channels, sequence_length)
-# # # sampling_freq : int
-# # # Original sampling frequency of the EEG signal
-# # # method : {'shift', 'am', 'fm'}
-# # # Sonification method to use
-# # # channels : Tuple[int, ...]
-# # # Channels to include in sonification
-# # # target_fs : int
-# # # Target audio sampling frequency
-# # # **kwargs : Dict[str, Any]
-# # # Additional parameters for specific sonification methods
-
-# # # Returns
-# # # -------
-# # # None
-# # # """
-
-# # # if not isinstance(signal_array, np.ndarray):
-# # # signal_array = np.array(signal_array)
-
-# # # if len(signal_array.shape) != 3:
-# # # raise ValueError(f"Expected 3D array, got shape {signal_array.shape}")
-
-# # # if max(channels) >= signal_array.shape[1]:
-# # # raise ValueError(
-# # # f"Channel index {max(channels)} out of range (max: {signal_array.shape[1]-1})"
-# # # )
-
-# # # selected_channels = signal_array[:, channels, :].mean(axis=1)
-# # # signal = selected_channels.mean(axis=0)
-
-# # # # Normalize
-# # # signal = signal / np.max(np.abs(signal))
-
-# # # # Apply selected method
-# # # if method == "shift":
-# # # audio = frequency_shift(signal, kwargs.get("shift_factor", 200))
-# # # elif method == "am":
-# # # audio = am_Modulation(
-# # # signal, kwargs.get("carrier_freq", 440), target_fs
-# # # )
-# # # elif method == "fm":
-# # # audio = fm_Modulation(
-# # # signal,
-# # # kwargs.get("carrier_freq", 440),
-# # # kwargs.get("sensitivity", 0.5),
-# # # target_fs,
-# # # )
-# # # else:
-# # # raise ValueError(f"Unknown method: {method}")
-
-# # # if len(audio) < 100:
-# # # raise ValueError("Audio signal too short after processing")
-
-# # # sd.play(audio, target_fs)
-# # # sd.wait()
-
-# # # def listen(
-# # # signal_array: np.ndarray,
-# # # sampling_freq: int,
-# # # channels: Tuple[int, ...] = (0, 1),
-# # # target_fs: int = 44_100,
-# # # ) -> None:
-# # # """
-# # # Play selected channels of a multichannel signal array as audio.
-
-# # # Example
-# # # -------
-# # # >>> signal = np.random.randn(1, 2, 1000) # Random stereo signal
-# # # >>> listen(signal, 16000, channels=(0, 1))
-
-# # # Parameters
-# # # ----------
-# # # signal_array : np.ndarray
-# # # Signal array of shape (batch_size, n_channels, sequence_length)
-# # # sampling_freq : int
-# # # Original sampling frequency of the signal
-# # # channels : Tuple[int, ...]
-# # # Tuple of channel indices to listen to
-# # # target_fs : int
-# # # Target sampling frequency for playback
-
-# # # Returns
-# # # -------
-# # # None
-# # # """
-# # # if not isinstance(signal_array, np.ndarray):
-# # # signal_array = np.array(signal_array)
-
-# # # if len(signal_array.shape) != 3:
-# # # raise ValueError(f"Expected 3D array, got shape {signal_array.shape}")
-
-# # # if max(channels) >= signal_array.shape[1]:
-# # # raise ValueError(
-# # # f"Channel index {max(channels)} out of range (max: {signal_array.shape[1]-1})"
-# # # )
-
-# # # selected_channels = signal_array[:, channels, :].mean(axis=1)
-# # # audio_signal = selected_channels.mean(axis=0)
-
-# # # if sampling_freq != target_fs:
-# # # num_samples = int(round(len(audio_signal) * target_fs / sampling_freq))
-# # # audio_signal = resample(audio_signal, num_samples)
-
-# # # sd.play(audio_signal, target_fs)
-# # # sd.wait()
-
-# # # def print_device_info() -> None:
-# # # """
-# # # Display information about the default audio output device.
-
-# # # Example
-# # # -------
-# # # >>> print_device_info()
-# # # Default Output Device Info:
-# # #
-# # # """
-# # # try:
-# # # device_info = sd.query_devices(kind="output")
-# # # print(f"Default Output Device Info: \n{device_info}")
-# # # except sd.PortAudioError as err:
-# # # print(f"Error querying audio devices: {err}")
-
-# # # if __name__ == "__main__":
-# # # import scitex
-
-# # # CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt)
-
-# # # # Generate a test signal if demo_sig fails
-# # # try:
-# # # signal, time_points, sampling_freq = scitex.dsp.demo_sig("chirp")
-# # # except Exception as err:
-# # # print(f"Failed to load demo signal: {err}")
-# # # # Generate a simple test signal
-# # # duration = 2 # seconds
-# # # sampling_freq = 1000 # Hz
-# # # t = np.linspace(0, duration, int(duration * sampling_freq))
-# # # test_signal = np.sin(2 * np.pi * 10 * t) # 10 Hz sine wave
-# # # signal = test_signal.reshape(1, 1, -1)
-
-# # # # Try to get audio device
-# # # try:
-# # # device_id = list_and_select_device()
-# # # sd.default.device = device_id
-# # # except Exception as err:
-# # # print(f"Failed to set audio device: {err}")
-# # # print("Using default audio device")
-# # # device_id = None
-
-# # # # Test different sonification methods with error handling
-# # # methods = ["shift", "am", "fm"]
-# # # for method in methods:
-# # # try:
-# # # print(f"\nTesting {method} sonification...")
-# # # sonify_eeg(signal, sampling_freq, method=method)
-# # # except Exception as err:
-# # # print(f"Failed to play {method} sonification: {err}")
-
-# # # scitex.session.close(CONFIG)
-
-# # # # EOF
-
-# # # # #!/usr/bin/env python3
-# # # # # -*- coding: utf-8 -*-
-# # # # # Time-stamp: "2024-11-07 18:58:37 (ywatanabe)"
-# # # # # File: ./scitex_repo/src/scitex/dsp/_listen.py
-
-# # # # import sys
-# # # # from typing import Tuple
-
-# # # # import matplotlib.pyplot as plt
-# # # # import scitex
-# # # # import numpy as np
-# # # # import sounddevice as sd
-# # # # from scipy.signal import resample
-
-# # # # """
-# # # # Functionality:
-# # # # - Provides audio playback functionality for multichannel signal arrays
-# # # # - Includes device selection and audio information display utilities
-# # # # Input:
-# # # # - Multichannel signal arrays (numpy.ndarray)
-# # # # - Sampling frequency and channel selection
-# # # # Output:
-# # # # - Audio playback through specified output device
-# # # # Prerequisites:
-# # # # - PortAudio library (install with: sudo apt-get install portaudio19-dev)
-# # # # - sounddevice package
-# # # # """
-
-# # # # """Imports"""
-# # # # """Config"""
-# # # # CONFIG = scitex.gen.load_configs()
-
-# # # # """Functions"""
-# # # # def listen(
-# # # # signal_array: np.ndarray,
-# # # # sampling_freq: int,
-# # # # channels: Tuple[int, ...] = (0, 1),
-# # # # target_fs: int = 44_100,
-# # # # ) -> None:
-# # # # """
-# # # # Play selected channels of a multichannel signal array as audio.
-
-# # # # Example
-# # # # -------
-# # # # >>> signal = np.random.randn(1, 2, 1000) # Random stereo signal
-# # # # >>> listen(signal, 16000, channels=(0, 1))
-
-# # # # Parameters
-# # # # ----------
-# # # # signal_array : np.ndarray
-# # # # Signal array of shape (batch_size, n_channels, sequence_length)
-# # # # sampling_freq : int
-# # # # Original sampling frequency of the signal
-# # # # channels : Tuple[int, ...]
-# # # # Tuple of channel indices to listen to
-# # # # target_fs : int
-# # # # Target sampling frequency for playback
-
-# # # # Returns
-# # # # -------
-# # # # None
-# # # # """
-# # # # if not isinstance(signal_array, np.ndarray):
-# # # # signal_array = np.array(signal_array)
-
-# # # # if len(signal_array.shape) != 3:
-# # # # raise ValueError(f"Expected 3D array, got shape {signal_array.shape}")
-
-# # # # if max(channels) >= signal_array.shape[1]:
-# # # # raise ValueError(f"Channel index {max(channels)} out of range (max: {signal_array.shape[1]-1})")
-
-# # # # selected_channels = signal_array[:, channels, :].mean(axis=1)
-# # # # audio_signal = selected_channels.mean(axis=0)
-
-# # # # if sampling_freq != target_fs:
-# # # # num_samples = int(round(len(audio_signal) * target_fs / sampling_freq))
-# # # # audio_signal = resample(audio_signal, num_samples)
-
-# # # # sd.play(audio_signal, target_fs)
-# # # # sd.wait()
-
-# # # # def print_device_info() -> None:
-# # # # """
-# # # # Display information about the default audio output device.
-
-# # # # Example
-# # # # -------
-# # # # >>> print_device_info()
-# # # # Default Output Device Info:
-# # # #
-# # # # """
-# # # # try:
-# # # # device_info = sd.query_devices(kind="output")
-# # # # print(f"Default Output Device Info: \n{device_info}")
-# # # # except sd.PortAudioError as err:
-# # # # print(f"Error querying audio devices: {err}")
-
-# # # # def list_and_select_device() -> int:
-# # # # """
-# # # # List available audio devices and prompt user to select one.
-
-# # # # Example
-# # # # -------
-# # # # >>> device_id = list_and_select_device()
-# # # # Available audio devices:
-# # # # ...
-# # # # Enter the ID of the device you want to use:
-
-# # # # Returns
-# # # # -------
-# # # # int
-# # # # Selected device ID
-# # # # """
-# # # # try:
-# # # # print("Available audio devices:")
-# # # # devices = sd.query_devices()
-# # # # print(devices)
-# # # # device_id = int(input("Enter the ID of the device you want to use: "))
-# # # # if device_id not in range(len(devices)):
-# # # # raise ValueError(f"Invalid device ID: {device_id}")
-# # # # return device_id
-# # # # except (ValueError, sd.PortAudioError) as err:
-# # # # print(f"Error during device selection: {err}")
-# # # # return 0
-
-# # # # def frequency_shift(signal: np.ndarray, shift_factor: int = 200) -> np.ndarray:
-# # # # """Direct frequency shifting"""
-# # # # num_samples = int(len(signal) * shift_factor)
-# # # # return resample(signal, num_samples)
-
-# # # # def am_Modulation(signal: np.ndarray, carrier_freq: float = 440, fs: int = 44100) -> np.ndarray:
-# # # # """Amplitude Modulation"""
-# # # # t = np.arange(len(signal)) / fs
-# # # # carrier = np.sin(2 * np.pi * carrier_freq * t)
-# # # # return (1 + signal) * carrier
-
-# # # # def fm_Modulation(signal: np.ndarray, carrier_freq: float = 440, sens: float = 0.5, fs: int = 44100) -> np.ndarray:
-# # # # """Frequency Modulation"""
-# # # # t = np.arange(len(signal)) / fs
-# # # # phase = 2 * np.pi * carrier_freq * t + sens * np.cumsum(signal) / fs
-# # # # return np.sin(phase)
-
-# # # # def sonify_eeg(
-# # # # signal_array: np.ndarray,
-# # # # sampling_freq: int,
-# # # # method: str = 'shift',
-# # # # channels: Tuple[int, ...] = (0, 1),
-# # # # target_fs: int = 44100,
-# # # # **kwargs
-# # # # ) -> None:
-# # # # """Main sonification function"""
-# # # # selected_channels = signal_array[:, channels, :].mean(axis=1)
-# # # # signal = selected_channels.mean(axis=0)
-
-# # # # # Normalize
-# # # # signal = signal / np.max(np.abs(signal))
-
-# # # # # Apply selected method
-# # # # if method == 'shift':
-# # # # audio = frequency_shift(signal, kwargs.get('shift_factor', 200))
-# # # # elif method == 'am':
-# # # # audio = am_Modulation(signal, kwargs.get('carrier_freq', 440), target_fs)
-# # # # elif method == 'fm':
-# # # # audio = fm_Modulation(signal, kwargs.get('carrier_freq', 440), kwargs.get('sensitivity', 0.5), target_fs)
-# # # # else:
-# # # # raise ValueError(f"Unknown method: {method}")
-
-# # # # sd.play(audio, target_fs)
-# # # # sd.wait()
-
-# # # # if __name__ == "__main__":
-# # # # import scitex
-
-# # # # CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt)
-
-# # # # signal, time_points, sampling_freq = scitex.dsp.demo_sig("chirp")
-
-# # # # device_id = list_and_select_device()
-# # # # sd.default.device = device_id
-
-# # # # listen(signal, sampling_freq)
-
-# # # # scitex.session.close(CONFIG)
-
-# # # #
-
-# # #
-
-# #
-
-#
-
-# EOF
diff --git a/src/scitex/dsp/_misc.py b/src/scitex/dsp/_misc.py
deleted file mode 100755
index 0437ffe41..000000000
--- a/src/scitex/dsp/_misc.py
+++ /dev/null
@@ -1,30 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-11-05 01:03:32 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/_misc.py
-
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-04-05 12:14:08 (ywatanabe)"
-
-from scitex.decorators import torch_fn
-
-
-@torch_fn
-def ensure_3d(x):
- if x.ndim == 1: # assumes (seq_len,)
- x = x.unsqueeze(0).unsqueeze(0)
- elif x.ndim == 2: # assumes (batch_siize, seq_len)
- x = x.unsqueeze(1)
- return x
-
-
-# @torch_fn
-# def unbias(x, dim=-1, fn="mean"):
-# if fn == "mean":
-# return x - x.mean(dim=dim, keepdims=True)
-# if fn == "min":
-# return x - x.min(dim=dim, keepdims=True)[0]
-
-
-# EOF
diff --git a/src/scitex/dsp/_mne.py b/src/scitex/dsp/_mne.py
deleted file mode 100755
index 36f3451b4..000000000
--- a/src/scitex/dsp/_mne.py
+++ /dev/null
@@ -1,43 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2024-11-04 02:07:36 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/_mne.py
-
-try:
- import mne
-
- MNE_AVAILABLE = True
-except ImportError:
- MNE_AVAILABLE = False
- mne = None
-
-import pandas as pd
-
-from .params import EEG_MONTAGE_1020
-
-
-def get_eeg_pos(channel_names=EEG_MONTAGE_1020):
- if not MNE_AVAILABLE:
- raise ImportError(
- "MNE-Python is not installed. Please install with: pip install mne"
- )
- # Load the standard 10-20 montage
- standard_montage = mne.channels.make_standard_montage("standard_1020")
- standard_montage.ch_names = [
- ch_name.upper() for ch_name in standard_montage.ch_names
- ]
-
- # Get the positions of the electrodes in the standard montage
- positions = standard_montage.get_positions()
-
- df = pd.DataFrame(positions["ch_pos"])[channel_names]
-
- df.set_index(pd.Series(["x", "y", "z"]))
-
- return df
-
-
-if __name__ == "__main__":
- print(get_eeg_pos())
-
-
-# EOF
diff --git a/src/scitex/dsp/_modulation_index.py b/src/scitex/dsp/_modulation_index.py
deleted file mode 100755
index 6a075a9df..000000000
--- a/src/scitex/dsp/_modulation_index.py
+++ /dev/null
@@ -1,93 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2024-11-04 02:09:55 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/_modulation_index.py
-
-try:
- import torch
-
- TORCH_AVAILABLE = True
-except ImportError:
- TORCH_AVAILABLE = False
- torch = None
-
-from scitex.decorators import signal_fn
-
-if TORCH_AVAILABLE:
- from scitex.nn._ModulationIndex import ModulationIndex
-
-
-@signal_fn
-def modulation_index(pha, amp, n_bins=18, amp_prob=False):
- """
- pha: (batch_size, n_chs, n_freqs_pha, n_segments, seq_len)
- amp: (batch_size, n_chs, n_freqs_amp, n_segments, seq_len)
- """
- if not TORCH_AVAILABLE:
- raise ImportError(
- "PyTorch is not installed. Please install with: pip install torch"
- )
- return ModulationIndex(n_bins=n_bins, amp_prob=amp_prob)(pha, amp)
-
-
-def _reshape(x, batch_size=2, n_chs=4):
- return (
- torch.tensor(x)
- .float()
- .unsqueeze(0)
- .unsqueeze(0)
- .repeat(batch_size, n_chs, 1, 1, 1)
- )
-
-
-if __name__ == "__main__":
- import sys
-
- import matplotlib.pyplot as plt
-
- import scitex
-
- # Start
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(
- sys, plt, fig_scale=3
- )
-
- # Parameters
- FS = 512
- T_SEC = 5
-
- # Demo signal
- xx, tt, fs = scitex.dsp.demo_sig(fs=FS, t_sec=T_SEC, sig_type="tensorpac")
- # xx.shape: (8, 19, 20, 512)
-
- # Tensorpac
- (
- pha,
- amp,
- freqs_pha,
- freqs_amp,
- pac_tp,
- ) = scitex.dsp.utils.pac.calc_pac_with_tensorpac(xx, fs, t_sec=T_SEC)
-
- # GPU calculation with scitex.dsp.nn.ModulationIndex
- pha, amp = _reshape(pha), _reshape(amp)
- pac_scitex = scitex.dsp.modulation_index(pha, amp).cpu().numpy()
- i_batch, i_ch = 0, 0
- pac_scitex = pac_scitex[i_batch, i_ch]
-
- # Plots
- fig = scitex.dsp.utils.pac.plot_PAC_scitex_vs_tensorpac(
- pac_scitex, pac_tp, freqs_pha, freqs_amp
- )
- fig.suptitle("MI (modulation index) calculation")
- scitex.io.save(fig, "modulation_index.png")
-
- # Close
- scitex.session.close(CONFIG)
-
-# EOF
-
-"""
-/home/ywatanabe/proj/entrance/scitex/dsp/_modulation_index.py
-"""
-
-# EOF
diff --git a/src/scitex/dsp/_pac.py b/src/scitex/dsp/_pac.py
deleted file mode 100755
index d4e31dcf8..000000000
--- a/src/scitex/dsp/_pac.py
+++ /dev/null
@@ -1,337 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2024-11-26 22:24:40 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/_pac.py
-
-THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/dsp/_pac.py"
-
-import sys
-
-import matplotlib.pyplot as plt
-import numpy as np
-
-try:
- import torch
-
- TORCH_AVAILABLE = True
-except ImportError:
- TORCH_AVAILABLE = False
- torch = None
-
-
-from scitex.decorators import signal_fn
-
-if TORCH_AVAILABLE:
- from scitex.nn._PAC import PAC
-
-
-def _check_torch():
- if not TORCH_AVAILABLE:
- raise ImportError(
- "PyTorch is not installed. Please install with: pip install torch"
- )
-
-
-"""
-scitex.dsp.pac function
-"""
-
-
-# @batch_fn
-@signal_fn
-def pac(
- x,
- fs,
- pha_start_hz=2,
- pha_end_hz=20,
- pha_n_bands=100,
- amp_start_hz=60,
- amp_end_hz=160,
- amp_n_bands=100,
- device="cuda",
- batch_size=1,
- batch_size_ch=-1,
- fp16=False,
- trainable=False,
- n_perm=None,
- amp_prob=False,
-):
- """
- Compute the phase-amplitude coupling (PAC) for signals. This function automatically handles inputs as
- PyTorch tensors, NumPy arrays, or pandas DataFrames.
-
- Arguments:
- - x (torch.Tensor | np.ndarray | pd.DataFrame): Input signal. Shape can be either (batch_size, n_chs, seq_len) or
- - fs (float): Sampling frequency of the input signal.
- - pha_start_hz (float, optional): Start frequency for phase bands. Default is 2 Hz.
- - pha_end_hz (float, optional): End frequency for phase bands. Default is 20 Hz.
- - pha_n_bands (int, optional): Number of phase bands. Default is 100.
- - amp_start_hz (float, optional): Start frequency for amplitude bands. Default is 60 Hz.
- - amp_end_hz (float, optional): End frequency for amplitude bands. Default is 160 Hz.
- - amp_n_bands (int, optional): Number of amplitude bands. Default is 100.
-
- Returns:
- - torch.Tensor: PAC values. Shape: (batch_size, n_chs, pha_n_bands, amp_n_bands)
- - numpy.ndarray: Phase bands used for the computation.
- - numpy.ndarray: Amplitude bands used for the computation.
-
- Example:
- FS = 512
- T_SEC = 4
- xx, tt, fs = scitex.dsp.demo_sig(
- batch_size=1, n_chs=1, fs=FS, t_sec=T_SEC, sig_type="tensorpac"
- )
- pac, pha_mids_hz, amp_mids_hz = scitex.dsp.pac(xx, fs)
- """
- _check_torch()
-
- def process_ch_batching(m, x, batch_size_ch, device):
- n_chs = x.shape[1]
- n_batches = (n_chs + batch_size_ch - 1) // batch_size_ch
-
- agg = []
- for ii in range(n_batches):
- start, end = batch_size_ch * ii, min(batch_size_ch * (ii + 1), n_chs)
- _pac = m(x[:, start:end, :].to(device)).detach().cpu()
- agg.append(_pac)
-
- # return np.concatenate(agg, axis=1)
- return torch.cat(agg, dim=1)
-
- m = PAC(
- x.shape[-1],
- fs,
- pha_start_hz=pha_start_hz,
- pha_end_hz=pha_end_hz,
- pha_n_bands=pha_n_bands,
- amp_start_hz=amp_start_hz,
- amp_end_hz=amp_end_hz,
- amp_n_bands=amp_n_bands,
- fp16=fp16,
- trainable=trainable,
- n_perm=n_perm,
- amp_prob=amp_prob,
- ).to(device)
-
- if batch_size_ch == -1:
- return m(x.to(device)), m.PHA_MIDS_HZ, m.AMP_MIDS_HZ
- else:
- return (
- process_ch_batching(m, x, batch_size_ch, device),
- m.PHA_MIDS_HZ,
- m.AMP_MIDS_HZ,
- )
-
-
-if __name__ == "__main__":
- import matplotlib.pyplot as plt
-
- import scitex
-
- # Start
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt)
-
- pac, freqs_pha, freqs_amp = scitex.dsp.pac(
- np.random.rand(1, 16, 24000),
- 400,
- batch_size=1,
- batch_size_ch=8,
- fp16=True,
- n_perm=16,
- )
-
-# # Parameters
-# FS = 512
-# T_SEC = 4
-# # IS_TRAINABLE = False
-# # FP16 = True
-
-# for IS_TRAINABLE in [True, False]:
-# for FP16 in [True, False]:
-
-# # Demo signal
-# xx, tt, fs = scitex.dsp.demo_sig(
-# batch_size=1, n_chs=1, fs=FS, t_sec=T_SEC, sig_type="pac"
-# )
-
-
-# # scitex.str.print_debug()
-# # xx = np.random.rand(1,16,24000)
-# # fs = 400
-
-# # scitex calculation
-# pac_scitex, pha_mids_scitex, amp_mids_scitex = scitex.dsp.pac(
-# xx,
-# fs,
-# pha_n_bands=50,
-# amp_n_bands=30,
-# trainable=IS_TRAINABLE,
-# fp16=FP16,
-# )
-# i_batch, i_ch = 0, 0
-# pac_scitex = pac_scitex[i_batch, i_ch]
-
-# printc(type(pac_scitex))
-
-# # Tensorpac calculation
-# (
-# _,
-# _,
-# _pha_mids_tp,
-# _amp_mids_tp,
-# pac_tp,
-# ) = scitex.dsp.utils.pac.calc_pac_with_tensorpac(xx, fs, T_SEC)
-
-# # Validates the consitency in frequency definitions
-# assert np.allclose(
-# pha_mids_scitex, _pha_mids_tp
-# )
-# assert np.allclose(
-# amp_mids_scitex, _amp_mids_tp
-# )
-
-# scitex.io.save(
-# (pac_scitex, pac_tp, pha_mids_scitex, amp_mids_scitex),
-# "./data/cache.npz",
-# )
-
-# # ################################################################################
-# # # cache
-# # pac_scitex, pac_tp, pha_mids_scitex, amp_mids_scitex = scitex.io.load(
-# # "./data/cache.npz"
-# # )
-# # ################################################################################
-
-# # Plots
-# fig = scitex.dsp.utils.pac.plot_PAC_scitex_vs_tensorpac(
-# pac_scitex, pac_tp, pha_mids_scitex, amp_mids_scitex
-# )
-# fig.suptitle(
-# "Phase-Amplitude Coupling calculation\n\n(Bandpass Filtering -> Hilbert Transformation-> Modulation Index)"
-# )
-# plt.show()
-
-# scitex.gen.reload(scitex.dsp)
-
-# # Saves the figure
-# trainable_str = "trainable" if IS_TRAINABLE else "static"
-# fp_str = "fp16" if FP16 else "fp32"
-# scitex.io.save(
-# fig, f"pac_with_{trainable_str}_bandpass_{fp_str}.png"
-# )
-
-
-# def run_method_tests():
-# import scitex
-
-# # Test parameters
-# FS = 512
-# T_SEC = 4
-
-# class PACProcessor:
-# @batch_torch_fn
-# def process_pac(self, x, fs, **kwargs):
-# return pac(x, fs, **kwargs)
-
-# @signal_fn
-# def process_signal(self, x):
-# return x * 2
-
-# def run_method_basic_tests():
-# processor = PACProcessor()
-
-# # Generate test signal
-# xx, tt, fs = scitex.dsp.demo_sig(
-# batch_size=1, n_chs=1, fs=FS, t_sec=T_SEC, sig_type="pac"
-# )
-
-# try:
-# # Test method with batch processing
-# result_batch, pha_mids, amp_mids = processor.process_pac(
-# xx, fs, pha_n_bands=50, amp_n_bands=30, batch_size=1
-# )
-# assert torch.is_tensor(result_batch)
-
-# # Test basic torch method
-# result_torch = processor.process_signal(xx)
-# assert torch.is_tensor(result_torch)
-
-# scitex.str.printc("Passed: Basic method tests", "yellow")
-# except Exception as err:
-# scitex.str.printc(f"Failed: Basic method tests - {str(err)}", "red")
-
-# def run_method_cuda_tests():
-# if not torch.cuda.is_available():
-# scitex.str.printc(
-# "CUDA method tests skipped: No GPU available", "yellow"
-# )
-# return
-
-# processor = PACProcessor()
-# xx, tt, fs = scitex.dsp.demo_sig(
-# batch_size=1, n_chs=1, fs=FS, t_sec=T_SEC, sig_type="pac"
-# )
-
-# try:
-# # Test with CUDA
-# result_cuda, _, _ = processor.process_pac(xx, fs, device="cuda")
-# assert result_cuda.device.type == "cuda"
-
-# result_torch = processor.process_signal(xx, device="cuda")
-# assert result_torch.device.type == "cuda"
-
-# scitex.str.printc("Passed: CUDA method tests", "yellow")
-# except Exception as err:
-# scitex.str.printc(f"Failed: CUDA method tests - {str(err)}", "red")
-
-# def run_method_batch_size_tests():
-# processor = PACProcessor()
-# batch_sizes = [1, 2, 4]
-
-# for batch_size in batch_sizes:
-# try:
-# xx, tt, fs = scitex.dsp.demo_sig(
-# batch_size=batch_size,
-# n_chs=1,
-# fs=FS,
-# t_sec=T_SEC,
-# sig_type="pac",
-# )
-
-# result, _, _ = processor.process_pac(
-# xx, fs, batch_size=batch_size
-# )
-# assert result.shape[0] == batch_size
-
-# scitex.str.printc(
-# f"Passed: Method batch size test with size={batch_size}",
-# "yellow",
-# )
-# except Exception as err:
-# scitex.str.printc(
-# f"Failed: Method batch size test with size={batch_size} - {str(err)}",
-# "red",
-# )
-
-# # Execute method test suites
-# test_suites = [
-# ("Method Basic Tests", run_method_basic_tests),
-# ("Method CUDA Tests", run_method_cuda_tests),
-# ("Method Batch Size Tests", run_method_batch_size_tests),
-# ]
-
-# for test_name, test_func in test_suites:
-# test_func()
-
-
-# if __name__ == "__main__":
-# run_method_tests()
-
-# # EOF
-
-# """
-# python -m scitex.dsp._pac
-# """
-
-#
-
-# EOF
diff --git a/src/scitex/dsp/_psd.py b/src/scitex/dsp/_psd.py
deleted file mode 100755
index 0e92f33e0..000000000
--- a/src/scitex/dsp/_psd.py
+++ /dev/null
@@ -1,114 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2024-11-04 02:11:25 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/_psd.py
-
-"""This script does XYZ."""
-
-try:
- import torch
-
- TORCH_AVAILABLE = True
-except ImportError:
- TORCH_AVAILABLE = False
- torch = None
-
-from scitex.decorators import signal_fn
-
-if TORCH_AVAILABLE:
- from scitex.nn._PSD import PSD
-
-
-@signal_fn
-def psd(
- x,
- fs,
- prob=False,
- dim=-1,
-):
- """
- import matplotlib.pyplot as plt
-
- x, t, fs = scitex.dsp.demo_sig() # (batch_size, n_chs, seq_len)
- pp, ff = psd(x, fs)
-
- # Plots
- plt, CC = scitex.plt.configure_mpl(plt)
- fig, ax = scitex.plt.subplots()
- ax.plot(fs, pp[0, 0])
- ax.xlabel("Frequency [Hz]")
- ax.ylabel("log(Power [uV^2 / Hz]) [a.u.]")
- plt.show()
- """
- if not TORCH_AVAILABLE:
- raise ImportError(
- "PyTorch is not installed. Please install with: pip install torch"
- )
- psd, freqs = PSD(fs, prob=prob, dim=dim)(x)
- return psd, freqs
-
-
-def band_powers(self, psd):
- """
- Calculate the average power for specified frequency bands.
- """
- assert len(self.low_freqs) == len(self.high_freqs)
-
- out = []
- for ll, hh in zip(self.low_freqs, self.high_freqs):
- band_indices = torch.where((freqs >= ll) & (freqs <= hh))[0].to(psd.device)
- band_power = psd[..., band_indices].sum(dim=self.dim)
- bandwidth = hh - ll
- avg_band_power = band_power / bandwidth
- out.append(avg_band_power)
- out = torch.stack(out, dim=-1)
- return out
-
- # Average Power in Each Frequency Band
- avg_band_powers = self.calc_band_avg_power(psd, freqs)
- return (avg_band_powers,)
-
-
-if __name__ == "__main__":
- import sys
-
- import matplotlib.pyplot as plt
-
- import scitex
-
- # Start
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt)
-
- # Parameters
- SIG_TYPE = "chirp"
-
- # Demo signal
- xx, tt, fs = scitex.dsp.demo_sig(SIG_TYPE) # (8, 19, 384)
-
- # PSD calculation
- pp, ff = psd(xx, fs, prob=True)
-
- # Plots
- fig, axes = scitex.plt.subplots(nrows=2)
-
- axes[0].plot(tt, xx[0, 0], label=SIG_TYPE)
- axes[1].set_title("Signal")
- axes[0].set_xlabel("Time [s]")
- axes[0].set_ylabel("Amplitude [?V]")
-
- axes[1].plot(ff, pp[0, 0])
- axes[1].set_title("PSD (power spectrum density)")
- axes[1].set_xlabel("Frequency [Hz]")
- axes[1].set_ylabel("Log(Power [?V^2 / Hz]) [a.u.]")
-
- scitex.io.save(fig, "psd.png")
-
- # Close
- scitex.session.close(CONFIG)
-
-# EOF
-
-"""
-/home/ywatanabe/proj/entrance/scitex/dsp/_psd.py
-"""
-
-# EOF
diff --git a/src/scitex/dsp/_resample.py b/src/scitex/dsp/_resample.py
deleted file mode 100755
index 14f5540e7..000000000
--- a/src/scitex/dsp/_resample.py
+++ /dev/null
@@ -1,86 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2024-04-13 02:35:11 (ywatanabe)"
-
-try:
- import torch
-
- TORCH_AVAILABLE = True
-except ImportError:
- TORCH_AVAILABLE = False
- torch = None
-
-try:
- import torchaudio.transforms as T
-
- TORCHAUDIO_AVAILABLE = True
-except ImportError:
- TORCHAUDIO_AVAILABLE = False
- T = None
-
-import scitex
-from scitex.decorators import signal_fn
-
-
-@signal_fn
-def resample(x, src_fs, tgt_fs, t=None):
- if not TORCH_AVAILABLE:
- raise ImportError(
- "PyTorch is not installed. Please install with: pip install torch"
- )
- if not TORCHAUDIO_AVAILABLE:
- raise ImportError(
- "torchaudio is not installed. Please install with: pip install torchaudio"
- )
- xr = T.Resample(src_fs, tgt_fs, dtype=x.dtype).to(x.device)(x)
- if t is None:
- return xr
- if t is not None:
- tr = torch.linspace(t[0], t[-1], xr.shape[-1])
- return xr, tr
-
-
-if __name__ == "__main__":
- import sys
-
- import matplotlib.pyplot as plt
-
- # Start
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt)
-
- # Parameters
- T_SEC = 1
- SIG_TYPE = "chirp"
- SRC_FS = 128
- TGT_FS_UP = 256
- TGT_FS_DOWN = 64
- FREQS_HZ = [10, 30, 100, 300]
-
- # Demo Signal
- xx, tt, fs = scitex.dsp.demo_sig(
- t_sec=T_SEC, fs=SRC_FS, freqs_hz=FREQS_HZ, sig_type=SIG_TYPE
- )
-
- # Resampling
- xd, td = scitex.dsp.resample(xx, fs, TGT_FS_DOWN, t=tt)
- xu, tu = scitex.dsp.resample(xx, fs, TGT_FS_UP, t=tt)
-
- # Plots
- i_batch, i_ch = 0, 0
- fig, axes = plt.subplots(nrows=3, sharex=True, sharey=True)
- axes[0].plot(tt, xx[i_batch, i_ch], label=f"Original ({SRC_FS} Hz)")
- axes[1].plot(td, xd[i_batch, i_ch], label=f"Down-sampled ({TGT_FS_DOWN} Hz)")
- axes[2].plot(tu, xu[i_batch, i_ch], label=f"Up-sampled ({TGT_FS_UP} Hz)")
- for ax in axes:
- ax.legend(loc="upper left")
-
- axes[-1].set_xlabel("Time [s]")
- fig.supylabel("Amplitude [?V]")
- fig.suptitle("Resampling")
- scitex.io.save(fig, "traces.png")
- # plt.show()
-
-# EOF
-
-"""
-/home/ywatanabe/proj/entrance/scitex/dsp/_resample.py
-"""
diff --git a/src/scitex/dsp/_skills/SKILL.md b/src/scitex/dsp/_skills/SKILL.md
deleted file mode 100644
index ad3348a91..000000000
--- a/src/scitex/dsp/_skills/SKILL.md
+++ /dev/null
@@ -1,56 +0,0 @@
----
-name: stx.dsp
-description: Digital signal processing for neuroscience — filtering, spectral analysis, phase-amplitude coupling, ripple detection, wavelets, and resampling.
----
-
-# stx.dsp — Skill Index
-
-Digital signal processing (DSP) utilities for neuroscience and time-series analysis. All major functions accept NumPy arrays, PyTorch tensors, or pandas DataFrames via the `@signal_fn` decorator and return the same type as input.
-
-## Sub-skills
-
-| File | Feature Area |
-|------|-------------|
-| [filtering.md](filtering.md) | Bandpass, bandstop, lowpass, highpass, Gaussian filters |
-| [spectral.md](spectral.md) | Power spectral density and band power extraction |
-| [hilbert.md](hilbert.md) | Analytic signal: amplitude envelope and instantaneous phase |
-| [pac.md](pac.md) | Phase-amplitude coupling (`pac`, `modulation_index`) |
-| [ripple-detection.md](ripple-detection.md) | Hippocampal sharp-wave ripple detection |
-| [wavelet.md](wavelet.md) | Continuous wavelet transform |
-| [resampling.md](resampling.md) | Anti-aliased up/down resampling |
-| [segmentation.md](segmentation.md) | Sliding-window segmentation and sktime conversion |
-| [noise.md](noise.md) | Add Gaussian, white, pink, or brown noise |
-| [normalization.md](normalization.md) | Z-score and min-max normalization |
-| [referencing.md](referencing.md) | Common-average, random, and target re-referencing |
-| [demo-signal.md](demo-signal.md) | Synthetic signal generation for testing |
-| [params.md](params.md) | Built-in EEG frequency bands and electrode montages |
-| [utils.md](utils.md) | Helpers: zero-padding, FIR filter design, differentiable bandpass filters |
-
-## Quick Start
-
-```python
-import scitex as stx
-import numpy as np
-
-# Generate demo signal: shape (batch=8, chs=19, time=2048)
-xx, tt, fs = stx.dsp.demo_sig(sig_type="chirp", fs=512, t_sec=4)
-
-# Bandpass filter 8-30 Hz
-xx_bp = stx.dsp.filt.bandpass(xx, fs, np.array([[8, 30]]))
-
-# Power spectral density
-psd_vals, freqs = stx.dsp.psd(xx, fs)
-
-# Wavelet transform -> phase, amplitude, frequency axis
-pha, amp, freqs_w = stx.dsp.wavelet(xx, fs)
-```
-
-## Optional Dependencies
-
-| Feature | Requires | Install |
-|---------|----------|---------|
-| Filters, PSD, PAC, wavelet, resample | `torch`, `torchaudio` | `pip install torch torchaudio` |
-| Audio device listing | `sounddevice`, PortAudio | `pip install sounddevice` + `apt install portaudio19-dev` |
-| EEG electrode positions | `mne` | `pip install mne` |
-| Ripple demo signal | `ripple_detection` | `pip install ripple_detection` |
-| Tensorpac demo / PAC comparison | `tensorpac` | `pip install tensorpac` |
diff --git a/src/scitex/dsp/_skills/demo-signal.md b/src/scitex/dsp/_skills/demo-signal.md
deleted file mode 100644
index 3039714f7..000000000
--- a/src/scitex/dsp/_skills/demo-signal.md
+++ /dev/null
@@ -1,105 +0,0 @@
----
-description: Generate synthetic signals for testing — periodic, chirp, ripple, Gaussian, PAC, MEG.
----
-
-# stx.dsp.demo_sig — Demo Signal Generation
-
-Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_demo_sig.py`
-
-## Signature
-
-```python
-xx, tt, fs = stx.dsp.demo_sig(
- sig_type="periodic",
- batch_size=8,
- n_chs=19,
- n_segments=20,
- t_sec=4,
- fs=512,
- freqs_hz=None,
- verbose=False,
-)
-```
-
-### Parameters
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `sig_type` | str | `"periodic"` | Signal type (see table below) |
-| `batch_size` | int | `8` | Number of batches |
-| `n_chs` | int | `19` | Number of channels |
-| `n_segments` | int | `20` | Segments (only for `"tensorpac"` and `"pac"`) |
-| `t_sec` | float | `4` | Duration in seconds |
-| `fs` | int | `512` | Sampling frequency in Hz |
-| `freqs_hz` | list or None | `None` | Frequencies for periodic signal; random if `None` |
-| `verbose` | bool | `False` | Print frequency information |
-
-### Returns
-
-- `xx`: signal array, shape depends on `sig_type` (see table)
-- `tt`: time vector, shape `(t_sec * fs,)`
-- `fs`: sampling frequency (same as input)
-
-## Signal types
-
-| `sig_type` | Output shape | Requires | Description |
-|------------|-------------|----------|-------------|
-| `"uniform"` | `(batch, chs, time)` | — | Uniform random in `[-0.5, 0.5]` |
-| `"gauss"` | `(batch, chs, time)` | — | Standard Gaussian noise |
-| `"periodic"` | `(batch, chs, time)` | — | Sum of sine waves at `freqs_hz` |
-| `"chirp"` | `(batch, chs, time)` | — | Linear frequency sweep with AM envelope |
-| `"ripple"` | `(batch, chs, time)` | `ripple_detection` | Simulated hippocampal LFP with ripples |
-| `"meg"` | `(batch, chs, time)` | `mne` | Real MEG segment from MNE sample dataset |
-| `"tensorpac"` | `(batch, chs, segs, time)` | `tensorpac` | PAC signal via `pac_signals_wavelet` |
-| `"pac"` | `(batch, chs, segs, time)` | — | Synthetic PAC: theta phase modulating gamma amplitude |
-
-## Examples
-
-```python
-import scitex as stx
-
-# Default: 8 batches, 19 channels, 4s at 512 Hz
-xx, tt, fs = stx.dsp.demo_sig()
-print(xx.shape) # (8, 19, 2048)
-
-# Chirp (frequency sweep)
-xx, tt, fs = stx.dsp.demo_sig(sig_type="chirp", fs=512, t_sec=2)
-
-# Periodic with specific frequencies
-xx, tt, fs = stx.dsp.demo_sig(
- sig_type="periodic",
- freqs_hz=[10, 30, 100, 300],
- fs=1024,
- t_sec=1,
- batch_size=4,
- n_chs=2,
-)
-
-# PAC signal with segment dimension (for modulation_index)
-xx, tt, fs = stx.dsp.demo_sig(sig_type="pac", n_segments=20, fs=512, t_sec=4)
-print(xx.shape) # (8, 19, 20, 2048)
-
-# PAC signal using tensorpac wavelet method
-xx, tt, fs = stx.dsp.demo_sig(sig_type="tensorpac", n_segments=20)
-
-# Ripple simulation
-xx, tt, fs = stx.dsp.demo_sig(sig_type="ripple", fs=1000, t_sec=10)
-```
-
-## Internal signal constructors
-
-These private functions can be called directly for 1D signals:
-
-```python
-from scitex.dsp._demo_sig import (
- _demo_sig_periodic_1d,
- _demo_sig_chirp_1d,
- _demo_sig_ripple_1d,
-)
-
-# Single channel periodic
-sig_1d = _demo_sig_periodic_1d(t_sec=2, fs=512, freqs_hz=[10, 40])
-
-# Single channel chirp
-sig_chirp = _demo_sig_chirp_1d(t_sec=2, fs=512, low_hz=5, high_hz=200)
-```
diff --git a/src/scitex/dsp/_skills/filtering.md b/src/scitex/dsp/_skills/filtering.md
deleted file mode 100644
index db5f3c711..000000000
--- a/src/scitex/dsp/_skills/filtering.md
+++ /dev/null
@@ -1,110 +0,0 @@
----
-description: Bandpass, bandstop, lowpass, highpass, and Gaussian filters for multi-channel signals.
----
-
-# stx.dsp.filt — Filtering
-
-Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/filt.py`
-
-All filter functions are decorated with `@signal_fn`, which means they accept NumPy arrays, PyTorch tensors, or pandas DataFrames and return the same type. Input shape must be `(batch_size, n_chs, seq_len)` or compatible broadcastable form.
-
-Filters are implemented as PyTorch neural network modules from `scitex.nn._Filters`. They require `torch`.
-
-## Function Signatures
-
-```python
-stx.dsp.filt.bandpass(x, fs, bands, t=None)
-stx.dsp.filt.bandstop(x, fs, bands, t=None)
-stx.dsp.filt.lowpass(x, fs, cutoffs_hz, t=None)
-stx.dsp.filt.highpass(x, fs, cutoffs_hz, t=None)
-stx.dsp.filt.gauss(x, sigma, t=None)
-```
-
-### Parameters
-
-| Parameter | Type | Description |
-|-----------|------|-------------|
-| `x` | ndarray / Tensor | Input signal, shape `(batch, chs, time)` |
-| `fs` | float | Sampling frequency in Hz |
-| `bands` | array `(n_bands, 2)` | `[[low_hz, high_hz], ...]` for bandpass/bandstop |
-| `cutoffs_hz` | array `(n_bands,)` | Cutoff frequencies for lowpass/highpass |
-| `sigma` | float | Gaussian kernel width in samples (standard deviations) |
-| `t` | ndarray or None | Optional time vector; if given, also returned |
-
-### Return values
-
-- Without `t`: filtered signal, same shape and type as input
-- With `t`: `(filtered_signal, time_vector)` — time vector is unchanged
-
-## Examples
-
-```python
-import scitex as stx
-import numpy as np
-
-xx, tt, fs = stx.dsp.demo_sig(sig_type="periodic", fs=1024, t_sec=1)
-# xx.shape: (8, 19, 1024)
-
-# Single band: [[low_hz, high_hz]]
-BANDS = np.array([[80, 310]])
-
-# Bandpass 80-310 Hz
-x_bp = stx.dsp.filt.bandpass(xx, fs, BANDS)
-
-# Bandstop 80-310 Hz (notch)
-x_bs = stx.dsp.filt.bandstop(xx, fs, BANDS)
-
-# Lowpass at 80 Hz
-x_lp = stx.dsp.filt.lowpass(xx, fs, BANDS[:, 0])
-
-# Highpass at 310 Hz
-x_hp = stx.dsp.filt.highpass(xx, fs, BANDS[:, 1])
-
-# Gaussian smoothing (sigma=3 samples)
-x_g = stx.dsp.filt.gauss(xx, sigma=3)
-
-# With time vector returned
-x_bp, t_bp = stx.dsp.filt.bandpass(xx, fs, BANDS, t=tt)
-```
-
-## Multi-band filtering
-
-`bandpass` and `bandstop` accept multiple bands at once. The output gains an extra dimension for each band:
-
-```python
-BANDS = np.array([[4, 8], [8, 13], [13, 30]]) # theta, alpha, beta
-x_multi = stx.dsp.filt.bandpass(xx, fs, BANDS)
-# x_multi.shape: (8, 19, 3, 1024) — extra dim for n_bands
-```
-
-## EEG use case
-
-```python
-# Ripple band detection preprocessing
-ripple_bands = np.array([[80, 140]])
-x_ripple = stx.dsp.filt.bandpass(xx, fs, ripple_bands)
-
-# Theta band for PAC phase
-theta_bands = np.array([[4, 8]])
-x_theta = stx.dsp.filt.bandpass(xx, fs, theta_bands)
-```
-
-## FIR filter design utility
-
-`stx.dsp.utils.filter.design_filter` exposes the underlying FIR design for inspection:
-
-```python
-from scitex.dsp.utils.filter import design_filter, plot_filter_responses
-
-xx, tt, fs = stx.dsp.demo_sig()
-seq_len = xx.shape[-1]
-
-# Returns filter coefficients (numpy array)
-bp_filter = design_filter(seq_len, fs, low_hz=30, high_hz=70)
-lp_filter = design_filter(seq_len, fs, low_hz=30)
-hp_filter = design_filter(seq_len, fs, high_hz=70)
-bs_filter = design_filter(seq_len, fs, low_hz=30, high_hz=70, is_bandstop=True)
-
-# Plot impulse + frequency response
-fig = plot_filter_responses(bp_filter, fs, title="Bandpass 30-70 Hz")
-```
diff --git a/src/scitex/dsp/_skills/hilbert.md b/src/scitex/dsp/_skills/hilbert.md
deleted file mode 100644
index 8004c2728..000000000
--- a/src/scitex/dsp/_skills/hilbert.md
+++ /dev/null
@@ -1,72 +0,0 @@
----
-description: Hilbert transform returning instantaneous phase and amplitude envelope.
----
-
-# stx.dsp.hilbert — Analytic Signal
-
-Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_hilbert.py`
-
-## Signature
-
-```python
-phase, amplitude = stx.dsp.hilbert(x, dim=-1)
-```
-
-Computes the analytic signal via the Hilbert transform using `scitex.nn._Hilbert`. Both outputs have the same shape as `x`.
-
-### Parameters
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `x` | ndarray / Tensor | required | Input signal, shape `(batch, chs, time)` |
-| `dim` | int | `-1` | Dimension along which to apply the transform |
-
-### Returns
-
-- `phase`: instantaneous phase in radians, same shape as `x`
-- `amplitude`: amplitude envelope (always non-negative), same shape as `x`
-
-Decorated with `@signal_fn`: accepts NumPy arrays, PyTorch tensors, or DataFrames; returns the same type.
-
-## Examples
-
-```python
-import scitex as stx
-
-xx, tt, fs = stx.dsp.demo_sig(sig_type="chirp", t_sec=1.0, fs=400)
-# xx.shape: (8, 19, 400)
-
-phase, amplitude = stx.dsp.hilbert(xx)
-# phase.shape: (8, 19, 400) — values in [-pi, pi]
-# amplitude.shape: (8, 19, 400) — non-negative envelope
-
-# Plot signal and envelope for first batch/channel
-import matplotlib.pyplot as plt
-fig, axes = plt.subplots(2, 1, sharex=True)
-axes[0].plot(tt, xx[0, 0], label="signal")
-axes[0].plot(tt, amplitude[0, 0], label="envelope")
-axes[0].legend()
-axes[1].plot(tt, phase[0, 0], label="phase [rad]")
-axes[1].legend()
-```
-
-## PAC preprocessing pipeline
-
-`hilbert` is used internally by `pac` and `detect_ripples`, but you can also call it directly for custom phase-amplitude analyses:
-
-```python
-import scitex as stx
-import numpy as np
-
-xx, tt, fs = stx.dsp.demo_sig(fs=512, t_sec=4)
-
-# Extract theta phase
-theta = stx.dsp.filt.bandpass(xx, fs, np.array([[4, 8]]))
-theta_phase, _ = stx.dsp.hilbert(theta)
-
-# Extract gamma amplitude
-gamma = stx.dsp.filt.bandpass(xx, fs, np.array([[32, 80]]))
-_, gamma_amp = stx.dsp.hilbert(gamma)
-
-# theta_phase and gamma_amp are ready for coupling analysis
-```
diff --git a/src/scitex/dsp/_skills/noise.md b/src/scitex/dsp/_skills/noise.md
deleted file mode 100644
index 9eacce2bd..000000000
--- a/src/scitex/dsp/_skills/noise.md
+++ /dev/null
@@ -1,72 +0,0 @@
----
-description: Add Gaussian, white, pink, or brown noise to signals.
----
-
-# stx.dsp.add_noise — Noise Addition
-
-Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/add_noise.py`
-
-All functions are decorated with `@signal_fn`, accept NumPy arrays or PyTorch tensors, and return the same type. Requires `torch`.
-
-## Functions
-
-```python
-noisy = stx.dsp.add_noise.gauss(x, amp=1.0)
-noisy = stx.dsp.add_noise.white(x, amp=1.0)
-noisy = stx.dsp.add_noise.pink(x, amp=1.0, dim=-1)
-noisy = stx.dsp.add_noise.brown(x, amp=1.0, dim=-1)
-```
-
-### Parameters
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `x` | ndarray / Tensor | required | Input signal |
-| `amp` | float | `1.0` | Noise amplitude (scale factor) |
-| `dim` | int | `-1` | Dimension along which to generate correlated noise (pink, brown only) |
-
-### Returns
-
-Signal with noise added, same shape and type as input.
-
-## Noise types
-
-| Function | Type | Spectrum |
-|----------|------|---------|
-| `gauss` | Gaussian | White (flat), samples from `N(0, amp)` |
-| `white` | Uniform | White (flat), samples from `U(-amp, amp)` |
-| `pink` | 1/f noise | Pink spectrum: power `~ 1/f` |
-| `brown` | Brownian | Red spectrum: cumulative sum of uniform, then min-max normalized |
-
-## Examples
-
-```python
-import scitex as stx
-
-xx, tt, fs = stx.dsp.demo_sig(fs=128, t_sec=1)
-
-# Add different noise types
-xx_gauss = stx.dsp.add_noise.gauss(xx, amp=0.5)
-xx_white = stx.dsp.add_noise.white(xx, amp=0.5)
-xx_pink = stx.dsp.add_noise.pink(xx, amp=0.5)
-xx_brown = stx.dsp.add_noise.brown(xx, amp=0.5)
-
-# Inspect noise alone
-noise = stx.dsp.add_noise.pink(xx, amp=1.0) - xx
-```
-
-## Data augmentation use case
-
-```python
-import scitex as stx
-import numpy as np
-
-xx, tt, fs = stx.dsp.demo_sig(batch_size=8, fs=256, t_sec=2)
-
-# Augment by mixing noise types with different amplitudes
-xx_aug_1 = stx.dsp.add_noise.gauss(xx, amp=0.1) # mild Gaussian
-xx_aug_2 = stx.dsp.add_noise.pink(xx, amp=0.3) # realistic 1/f noise
-
-augmented = np.concatenate([xx, xx_aug_1, xx_aug_2], axis=0)
-# augmented.shape: (24, 19, 512)
-```
diff --git a/src/scitex/dsp/_skills/normalization.md b/src/scitex/dsp/_skills/normalization.md
deleted file mode 100644
index 332593749..000000000
--- a/src/scitex/dsp/_skills/normalization.md
+++ /dev/null
@@ -1,78 +0,0 @@
----
-description: Z-score and min-max normalization for multi-channel signals.
----
-
-# stx.dsp.norm — Normalization
-
-Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/norm.py`
-
-Both functions are decorated with `@signal_fn`, accept NumPy arrays or PyTorch tensors, and return the same type. Requires `torch`.
-
-## Functions
-
-```python
-x_z = stx.dsp.norm.z(x, dim=-1)
-x_mm = stx.dsp.norm.minmax(x, amp=1.0, dim=-1, fn="mean")
-```
-
-### stx.dsp.norm.z
-
-Z-score normalization: `(x - mean) / std` along `dim`.
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `x` | ndarray / Tensor | required | Input signal |
-| `dim` | int | `-1` | Dimension to normalize along |
-
-Returns tensor with mean 0 and std 1 along `dim`.
-
-### stx.dsp.norm.minmax
-
-Min-max normalization scaled to `[-amp, amp]` using the max absolute value.
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `x` | ndarray / Tensor | required | Input signal |
-| `amp` | float | `1.0` | Output amplitude scale (result bounded by `[-amp, amp]`) |
-| `dim` | int | `-1` | Dimension to normalize along |
-| `fn` | str | `"mean"` | Unused (present for API compatibility) |
-
-Implementation: divides by `max(|max|, |min|)` so the output fits in `[-amp, amp]`.
-
-## Examples
-
-```python
-import scitex as stx
-
-xx, tt, fs = stx.dsp.demo_sig(fs=256, t_sec=2)
-# xx.shape: (8, 19, 512)
-
-# Z-score normalize along time dimension
-xx_z = stx.dsp.norm.z(xx, dim=-1)
-# Each channel has mean ~0 and std ~1
-
-# Min-max normalize to [-1, 1] range
-xx_mm = stx.dsp.norm.minmax(xx, amp=1.0, dim=-1)
-
-# Normalize along channel dimension (across channels per time point)
-xx_ch = stx.dsp.norm.z(xx, dim=1)
-
-# Normalize each batch independently
-xx_b = stx.dsp.norm.z(xx, dim=-1)
-```
-
-## Common use cases
-
-```python
-import scitex as stx
-
-xx, tt, fs = stx.dsp.demo_sig()
-
-# Pre-normalize before bandpass filtering
-xx_norm = stx.dsp.norm.z(xx)
-xx_filt = stx.dsp.filt.bandpass(xx_norm, fs, [[4, 8]])
-
-# Normalize after wavelet transform for visualization
-pha, amp, freqs = stx.dsp.wavelet(xx, fs)
-amp_norm = stx.dsp.norm.minmax(amp, amp=1.0)
-```
diff --git a/src/scitex/dsp/_skills/pac.md b/src/scitex/dsp/_skills/pac.md
deleted file mode 100644
index dc7f51917..000000000
--- a/src/scitex/dsp/_skills/pac.md
+++ /dev/null
@@ -1,154 +0,0 @@
----
-description: Phase-amplitude coupling (PAC) via GPU-accelerated bandpass filtering and modulation index.
----
-
-# stx.dsp — Phase-Amplitude Coupling
-
-Sources:
-- `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_pac.py`
-- `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_modulation_index.py`
-
-Both functions require `torch`. `pac` additionally benefits from CUDA.
-
-## stx.dsp.pac
-
-High-level end-to-end PAC from raw signal.
-
-```python
-pac_vals, pha_mids_hz, amp_mids_hz = stx.dsp.pac(
- x,
- fs,
- pha_start_hz=2,
- pha_end_hz=20,
- pha_n_bands=100,
- amp_start_hz=60,
- amp_end_hz=160,
- amp_n_bands=100,
- device="cuda",
- batch_size=1,
- batch_size_ch=-1,
- fp16=False,
- trainable=False,
- n_perm=None,
- amp_prob=False,
-)
-```
-
-### Parameters
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `x` | ndarray / Tensor | required | Signal, shape `(batch, chs, time)` or `(batch, chs, segments, time)` |
-| `fs` | float | required | Sampling frequency in Hz |
-| `pha_start_hz` | float | `2` | Start of phase frequency range |
-| `pha_end_hz` | float | `20` | End of phase frequency range |
-| `pha_n_bands` | int | `100` | Number of phase bands to compute |
-| `amp_start_hz` | float | `60` | Start of amplitude frequency range |
-| `amp_end_hz` | float | `160` | End of amplitude frequency range |
-| `amp_n_bands` | int | `100` | Number of amplitude bands to compute |
-| `device` | str | `"cuda"` | PyTorch device, falls back to CPU if no GPU |
-| `batch_size` | int | `1` | Batch size for processing |
-| `batch_size_ch` | int | `-1` | Channel batch size; `-1` processes all at once |
-| `fp16` | bool | `False` | Use half-precision for memory efficiency |
-| `trainable` | bool | `False` | Make filter bank parameters learnable (gradient flows) |
-| `n_perm` | int or None | `None` | Number of permutations for surrogate testing |
-| `amp_prob` | bool | `False` | Normalize amplitude to probability distribution |
-
-### Returns
-
-- `pac_vals`: PAC values, shape `(batch, chs, pha_n_bands, amp_n_bands)`
-- `pha_mids_hz`: center frequencies of phase bands, shape `(pha_n_bands,)`
-- `amp_mids_hz`: center frequencies of amplitude bands, shape `(amp_n_bands,)`
-
-## stx.dsp.modulation_index
-
-Lower-level function: compute PAC from pre-computed phase and amplitude arrays.
-
-```python
-mi = stx.dsp.modulation_index(pha, amp, n_bins=18, amp_prob=False)
-```
-
-### Parameters
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `pha` | Tensor | required | Phase signal, shape `(batch, chs, n_freqs_pha, n_segments, seq_len)` |
-| `amp` | Tensor | required | Amplitude signal, shape `(batch, chs, n_freqs_amp, n_segments, seq_len)` |
-| `n_bins` | int | `18` | Number of phase bins for the mean-vector length computation |
-| `amp_prob` | bool | `False` | Normalize amplitude distribution |
-
-### Returns
-
-- `mi`: modulation index values
-
-## Helper: \_reshape
-
-```python
-reshaped = stx.dsp._reshape(x, batch_size=2, n_chs=4)
-```
-
-Utility to reshape a raw PAC tensor `x` into `(batch, chs, ...)` format for `modulation_index`.
-
-## Examples
-
-### End-to-end PAC
-
-```python
-import scitex as stx
-
-FS = 512
-xx, tt, fs = stx.dsp.demo_sig(
- batch_size=1, n_chs=1, fs=FS, t_sec=4, sig_type="tensorpac"
-)
-
-pac_vals, pha_mids, amp_mids = stx.dsp.pac(
- xx, fs,
- pha_start_hz=2, pha_end_hz=20, pha_n_bands=50,
- amp_start_hz=60, amp_end_hz=160, amp_n_bands=30,
-)
-# pac_vals.shape: (1, 1, 50, 30)
-
-# Plot comodulogram
-import matplotlib.pyplot as plt
-fig, ax = plt.subplots()
-im = ax.imshow(pac_vals[0, 0].T, origin="lower", aspect="auto")
-ax.set_xlabel("Phase frequency [Hz]")
-ax.set_ylabel("Amplitude frequency [Hz]")
-plt.colorbar(im, label="PAC (MI)")
-```
-
-### Memory-efficient channel batching
-
-```python
-pac_vals, pha_mids, amp_mids = stx.dsp.pac(
- xx, fs,
- batch_size_ch=4, # process 4 channels at a time
- fp16=True, # halve memory usage
-)
-```
-
-### Trainable filter banks (gradient-based optimization)
-
-```python
-# Filters become nn.Parameters; gradients flow through
-pac_vals, pha_mids, amp_mids = stx.dsp.pac(
- xx, fs, trainable=True
-)
-pac_vals.sum().backward() # gradients available on pha_mids / amp_mids
-```
-
-### Compare with Tensorpac
-
-```python
-from scitex.dsp.utils.pac import calc_pac_with_tensorpac, plot_pac_scitex_vs_tensorpac
-
-# Tensorpac reference calculation
-phases, amplitudes, freqs_pha, freqs_amp, pac_tp = calc_pac_with_tensorpac(
- xx, fs, t_sec=4, i_batch=0, i_ch=0
-)
-
-# Plot side-by-side comparison
-fig = plot_pac_scitex_vs_tensorpac(
- pac_vals[0, 0], pac_tp, freqs_pha, freqs_amp
-)
-```
diff --git a/src/scitex/dsp/_skills/params.md b/src/scitex/dsp/_skills/params.md
deleted file mode 100644
index 0832b2dd3..000000000
--- a/src/scitex/dsp/_skills/params.md
+++ /dev/null
@@ -1,103 +0,0 @@
----
-description: Built-in EEG frequency bands and standard electrode montages.
----
-
-# stx.dsp.params — Parameters and Constants
-
-Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/params.py`
-
-## stx.dsp.params.BANDS
-
-Standard EEG frequency bands as a pandas DataFrame.
-
-```python
-import scitex as stx
-
-print(stx.dsp.params.BANDS)
-# delta theta lalpha halpha beta gamma
-# low_hz 0.5 4.0 8.0 10.0 13.0 32.0
-# high_hz 4.0 8.0 10.0 13.0 32.0 75.0
-```
-
-### Access patterns
-
-```python
-bands = stx.dsp.params.BANDS
-
-# Get a single band's range
-delta_low = bands["delta"]["low_hz"] # 0.5
-delta_high = bands["delta"]["high_hz"] # 4.0
-
-# Get as array for use with stx.dsp.filt.bandpass
-import numpy as np
-
-# All bands stacked
-all_bands = bands.values.T # shape (6, 2): [[low1, high1], ...]
-
-# Single band
-theta_band = np.array([[bands["theta"]["low_hz"], bands["theta"]["high_hz"]]])
-# [[4.0, 8.0]]
-
-# Multiple specific bands
-gamma_beta = np.array([
- [bands["beta"]["low_hz"], bands["beta"]["high_hz"]],
- [bands["gamma"]["low_hz"], bands["gamma"]["high_hz"]],
-])
-```
-
-### Use with filtering
-
-```python
-import scitex as stx
-import numpy as np
-
-xx, tt, fs = stx.dsp.demo_sig(fs=256, t_sec=4)
-
-bands = stx.dsp.params.BANDS
-
-# Bandpass filter to theta (4-8 Hz)
-theta_band = np.array([[bands["theta"]["low_hz"], bands["theta"]["high_hz"]]])
-xx_theta = stx.dsp.filt.bandpass(xx, fs, theta_band)
-
-# Filter to all standard bands at once
-all_bands_array = bands.values.T # (6, 2)
-xx_all_bands = stx.dsp.filt.bandpass(xx, fs, all_bands_array)
-# xx_all_bands.shape: (batch, chs, 6, time)
-```
-
-## stx.dsp.params.EEG_MONTAGE_1020
-
-Standard 10-20 EEG electrode names (19 electrodes).
-
-```python
-stx.dsp.params.EEG_MONTAGE_1020
-# ['FP1', 'F3', 'C3', 'P3', 'O1',
-# 'FP2', 'F4', 'C4', 'P4', 'O2',
-# 'F7', 'T7', 'P7', 'F8', 'T8', 'P8',
-# 'FZ', 'CZ', 'PZ']
-```
-
-Used as default in `stx.dsp.get_eeg_pos()`.
-
-## stx.dsp.params.EEG_MONTAGE_BIPOLAR_TRANVERSE
-
-Bipolar transverse montage (14 channel pairs).
-
-```python
-stx.dsp.params.EEG_MONTAGE_BIPOLAR_TRANVERSE
-# ['FP1-FP2', 'F7-F3', 'F3-FZ', 'FZ-F4', 'F4-F8',
-# 'T7-C3', 'C3-CZ', 'CZ-C4', 'C4-T8',
-# 'P7-P3', 'P3-PZ', 'PZ-P4', 'P4-P8',
-# 'O1-O2']
-```
-
-## stx.dsp.get_eeg_pos (optional, requires MNE)
-
-Get 3D electrode positions from the standard 10-20 montage.
-
-```python
-df = stx.dsp.get_eeg_pos(channel_names=stx.dsp.params.EEG_MONTAGE_1020)
-# Returns DataFrame with columns = channel names, rows = [x, y, z]
-```
-
-Raises `ImportError` if `mne` is not installed.
diff --git a/src/scitex/dsp/_skills/referencing.md b/src/scitex/dsp/_skills/referencing.md
deleted file mode 100644
index e83c01acb..000000000
--- a/src/scitex/dsp/_skills/referencing.md
+++ /dev/null
@@ -1,92 +0,0 @@
----
-description: Common-average, random, and target channel re-referencing for EEG/LFP signals.
----
-
-# stx.dsp.reference — Re-referencing
-
-Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/reference.py`
-
-All functions are decorated with `@torch_fn`, accept NumPy arrays or PyTorch tensors, and return the same type. They require `torch`.
-
-Re-referencing is applied along the channel dimension (`dim=-2` by default, i.e., the second-to-last axis in a `(batch, chs, time)` tensor).
-
-## Functions
-
-```python
-re_ref = stx.dsp.reference.common_average(x, dim=-2)
-re_ref = stx.dsp.reference.random(x, dim=-2)
-re_ref = stx.dsp.reference.take_reference(x, tgt_indi, dim=-2)
-```
-
-### stx.dsp.reference.common_average
-
-Subtract the mean across all channels (common average reference), then z-score.
-
-Formula: `(x - mean(x, dim)) / std(x, dim)`
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `x` | ndarray / Tensor | required | Signal `(batch, chs, time)` |
-| `dim` | int | `-2` | Channel dimension |
-
-### stx.dsp.reference.random
-
-Subtract a random permutation of the channel data from the original.
-
-Each call produces a different result (non-deterministic). Useful for data augmentation.
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `x` | ndarray / Tensor | required | Signal `(batch, chs, time)` |
-| `dim` | int | `-2` | Channel dimension |
-
-### stx.dsp.reference.take_reference
-
-Subtract a specific channel (or set of channels) from all channels.
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `x` | ndarray / Tensor | required | Signal `(batch, chs, time)` |
-| `tgt_indi` | int or slice | required | Index/indices of the reference channel(s) |
-| `dim` | int | `-2` | Channel dimension |
-
-## Examples
-
-```python
-import scitex as stx
-
-xx, tt, fs = stx.dsp.demo_sig(n_chs=19, fs=256, t_sec=2)
-# xx.shape: (8, 19, 512)
-
-# Common average reference (standard for EEG)
-xx_car = stx.dsp.reference.common_average(xx)
-assert xx_car.shape == xx.shape
-
-# Random reference (data augmentation)
-xx_rand = stx.dsp.reference.random(xx)
-
-# Reference to channel 0 (linked mastoid, for example)
-xx_ref0 = stx.dsp.reference.take_reference(xx, tgt_indi=0)
-
-# Reference to average of channels 17 and 18 (bilateral mastoids)
-xx_bm = stx.dsp.reference.take_reference(xx, tgt_indi=slice(17, 19))
-```
-
-## EEG montage pipeline
-
-```python
-import scitex as stx
-import numpy as np
-
-# Load raw EEG (assumed shape: batch, chs, time)
-xx, tt, fs = stx.dsp.demo_sig(sig_type="meg", n_chs=19, fs=256)
-
-# 1. Re-reference to common average
-xx = stx.dsp.reference.common_average(xx)
-
-# 2. Bandpass filter to remove DC and high-frequency noise
-xx = stx.dsp.filt.bandpass(xx, fs, np.array([[0.5, 80]]))
-
-# 3. Z-score normalize
-xx = stx.dsp.norm.z(xx)
-```
diff --git a/src/scitex/dsp/_skills/resampling.md b/src/scitex/dsp/_skills/resampling.md
deleted file mode 100644
index fe0b533d2..000000000
--- a/src/scitex/dsp/_skills/resampling.md
+++ /dev/null
@@ -1,83 +0,0 @@
----
-description: Anti-aliased signal resampling using torchaudio.
----
-
-# stx.dsp.resample — Resampling
-
-Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_resample.py`
-
-## Signature
-
-```python
-xr = stx.dsp.resample(x, src_fs, tgt_fs, t=None)
-# or, with time vector:
-xr, tr = stx.dsp.resample(x, src_fs, tgt_fs, t=tt)
-```
-
-Uses `torchaudio.transforms.Resample` for polyphase anti-aliased resampling. Decorated with `@signal_fn`.
-
-Requires `torch` and `torchaudio`.
-
-### Parameters
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `x` | ndarray / Tensor | required | Signal, shape `(batch, chs, time)` |
-| `src_fs` | float | required | Source sampling frequency in Hz |
-| `tgt_fs` | float | required | Target sampling frequency in Hz |
-| `t` | ndarray or None | `None` | Optional time vector; if given, a resampled time vector is also returned |
-
-### Returns
-
-- `xr`: resampled signal, shape `(batch, chs, new_time)` where `new_time = round(time * tgt_fs / src_fs)`
-- If `t` is provided: `(xr, tr)` where `tr` is a new time vector spanning the same range as `t`
-
-## Examples
-
-```python
-import scitex as stx
-
-T_SEC = 1
-SRC_FS = 128
-xx, tt, fs = stx.dsp.demo_sig(sig_type="chirp", t_sec=T_SEC, fs=SRC_FS)
-
-# Downsample to 64 Hz
-xd, td = stx.dsp.resample(xx, fs, 64, t=tt)
-print(f"Original: {xx.shape}, {tt.shape}")
-print(f"Downsampled: {xd.shape}, {td.shape}")
-
-# Upsample to 256 Hz
-xu, tu = stx.dsp.resample(xx, fs, 256, t=tt)
-print(f"Upsampled: {xu.shape}, {tu.shape}")
-
-# Without time vector
-xd = stx.dsp.resample(xx, fs, 64)
-```
-
-## Common use cases
-
-```python
-import scitex as stx
-import numpy as np
-
-xx, tt, fs = stx.dsp.demo_sig(fs=1000, t_sec=10)
-
-# Preprocessing pipeline: downsample before filtering for speed
-xx_256, tt_256 = stx.dsp.resample(xx, fs, 256, t=tt)
-xx_filt = stx.dsp.filt.bandpass(xx_256, 256, np.array([[4, 80]]))
-
-# Resample before ripple detection (detect_ripples does this internally)
-# but you can pre-downsample manually:
-xx_low = stx.dsp.resample(xx, fs, 300) # ripple band is 80-140 Hz; 3x = 420 Hz
-
-# Match sampling rates between two recordings
-xx_a, tt_a, fs_a = stx.dsp.demo_sig(fs=1024)
-xx_b, tt_b, fs_b = stx.dsp.demo_sig(fs=512)
-xx_a_resampled = stx.dsp.resample(xx_a, fs_a, fs_b)
-```
-
-## Notes
-
-- The resampled time vector `tr` is computed with `torch.linspace(t[0], t[-1], new_len)`, preserving the original time span.
-- Resampling uses the dtype of the input tensor; convert to float32 first if needed.
-- `detect_ripples` internally downsamples to `low_hz * 3` Hz automatically.
diff --git a/src/scitex/dsp/_skills/ripple-detection.md b/src/scitex/dsp/_skills/ripple-detection.md
deleted file mode 100644
index 9b9b7ff2f..000000000
--- a/src/scitex/dsp/_skills/ripple-detection.md
+++ /dev/null
@@ -1,122 +0,0 @@
----
-description: Detect hippocampal sharp-wave ripples from wide-band LFP or EEG signals.
----
-
-# stx.dsp.detect_ripples — Ripple Detection
-
-Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_detect_ripples.py`
-
-## Signature
-
-```python
-df = stx.dsp.detect_ripples(
- xx,
- fs,
- low_hz,
- high_hz,
- sd=2.0,
- smoothing_sigma_ms=4,
- min_duration_ms=10,
- return_preprocessed_signal=False,
-)
-```
-
-### Parameters
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `xx` | ndarray | required | Signal, shape `(n_chs, time)` or `(batch, n_chs, time)` |
-| `fs` | float | required | Sampling frequency in Hz |
-| `low_hz` | float | required | Lower edge of ripple band (e.g. `80`) |
-| `high_hz` | float | required | Upper edge of ripple band (e.g. `140`) |
-| `sd` | float | `2.0` | Threshold in standard deviations above mean for peak detection |
-| `smoothing_sigma_ms` | float | `4` | Gaussian smoothing width in milliseconds |
-| `min_duration_ms` | float | `10` | Minimum ripple duration in milliseconds |
-| `return_preprocessed_signal` | bool | `False` | Also return the preprocessed RMS envelope |
-
-### Returns
-
-- If `return_preprocessed_signal=False` (default): pandas `DataFrame`
-- If `return_preprocessed_signal=True`: `(df, xx_r, fs_r)` where `xx_r` is the preprocessed signal and `fs_r` is the downsampled fs
-
-## Output DataFrame Columns
-
-| Column | Type | Description |
-|--------|------|-------------|
-| `start_s` | float | Event start time in seconds |
-| `end_s` | float | Event end time in seconds |
-| `duration_s` | float | Event duration in seconds |
-| `peak_s` | float | Time of peak amplitude in seconds |
-| `rel_peak_pos` | float | Peak position within event, 0.0–1.0 |
-| `peak_amp_sd` | float | Peak amplitude in standard deviations |
-
-The DataFrame index holds the channel index for each detected event.
-
-## Preprocessing pipeline (internal)
-
-1. **Downsample** to `low_hz * 3` Hz to speed up computation
-2. **Common-average subtraction** across channels to reduce EMG artifacts
-3. **Bandpass filter** to `[low_hz, high_hz]`
-4. **RMS envelope**: square, Hilbert amplitude, Gaussian smooth, square-root
-5. **Average across channels**, then **z-score** normalization
-6. **Peak detection** at threshold `sd` standard deviations
-7. **Edge removal**: drop events within `3 / low_hz` seconds of recording edges
-
-## Examples
-
-```python
-import scitex as stx
-
-# Generate synthetic ripple signal (requires ripple_detection package)
-xx, tt, fs = stx.dsp.demo_sig(sig_type="ripple", fs=1000, t_sec=10)
-
-# Detect ripples in 80-140 Hz band
-df = stx.dsp.detect_ripples(xx, fs, low_hz=80, high_hz=140)
-print(df)
-# start_s end_s duration_s peak_s rel_peak_pos peak_amp_sd
-
-# Custom threshold: require 3 SD and 20 ms minimum duration
-df = stx.dsp.detect_ripples(
- xx, fs,
- low_hz=80, high_hz=140,
- sd=3.0,
- min_duration_ms=20,
-)
-
-# Get preprocessed envelope for inspection
-df, xx_r, fs_r = stx.dsp.detect_ripples(
- xx, fs,
- low_hz=80, high_hz=140,
- return_preprocessed_signal=True,
-)
-print(f"Preprocessed fs: {fs_r} Hz, shape: {xx_r.shape}")
-```
-
-## Low-level helper functions
-
-These are exported from `stx.dsp` for advanced use:
-
-```python
-# Preprocessing (bandpass, RMS, z-score)
-xx_r, fs_r = stx.dsp._preprocess(xx, fs, low_hz=80, high_hz=140)
-
-# Event detection from preprocessed signal
-df = stx.dsp._find_events(xx_r, fs_r, sd=2.0, min_duration_ms=10)
-
-# Drop detections near recording edges
-df = stx.dsp._drop_ripples_at_edges(df, low_hz=80, xx_r=xx_r, fs_r=fs_r)
-
-# Add relative peak position column
-df = stx.dsp._calc_relative_peak_position(df)
-
-# Reorder columns into canonical order
-df = stx.dsp._sort_columns(df)
-```
-
-## Typical ripple band parameters
-
-| Region | Low Hz | High Hz |
-|--------|--------|---------|
-| Hippocampus SWR | 80 | 140 |
-| Fast ripples | 200 | 400 |
-| Sleep spindles | 12 | 15 |
diff --git a/src/scitex/dsp/_skills/segmentation.md b/src/scitex/dsp/_skills/segmentation.md
deleted file mode 100644
index 4d3b6a2fd..000000000
--- a/src/scitex/dsp/_skills/segmentation.md
+++ /dev/null
@@ -1,134 +0,0 @@
----
-description: Sliding-window segmentation, signal cropping, and sktime DataFrame conversion.
----
-
-# stx.dsp — Segmentation
-
-Sources:
-- `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_transform.py`
-- `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_crop.py`
-
-## stx.dsp.to_segments
-
-PyTorch-based sliding-window segmentation using `unfold`.
-
-```python
-windows = stx.dsp.to_segments(x, window_size, overlap_factor=1, dim=-1)
-```
-
-Decorated with `@torch_fn`.
-
-### Parameters
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `x` | Tensor / ndarray | required | Input signal |
-| `window_size` | int | required | Number of samples per window |
-| `overlap_factor` | int | `1` | Stride = `window_size // overlap_factor`; `1` = no overlap, `2` = 50% overlap |
-| `dim` | int | `-1` | Time dimension to segment |
-
-### Returns
-
-Tensor with a new trailing dimension: `(..., n_windows, window_size)`.
-
-### Example
-
-```python
-import scitex as stx
-
-xx, tt, fs = stx.dsp.demo_sig()
-# xx.shape: (8, 19, 2048)
-
-# Non-overlapping 256-sample windows
-segments = stx.dsp.to_segments(xx, window_size=256)
-# segments.shape: (8, 19, n_windows, 256)
-
-# 50% overlapping windows
-segments_50 = stx.dsp.to_segments(xx, window_size=256, overlap_factor=2)
-```
-
-## stx.dsp.crop
-
-NumPy-based signal cropping into windows. More flexible than `to_segments`: works on any axis and returns an optional time array.
-
-```python
-cropped = stx.dsp.crop(sig_2d, window_length, overlap_factor=0.0, axis=-1, time=None)
-# or
-cropped, cropped_times = stx.dsp.crop(sig_2d, window_length, overlap_factor=0.5, time=tt)
-```
-
-### Parameters
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `sig_2d` | ndarray | required | Signal array, any dimensionality |
-| `window_length` | int | required | Window length in samples |
-| `overlap_factor` | float | `0.0` | Fraction overlap between windows (0.0 = no overlap, 0.5 = 50%) |
-| `axis` | int | `-1` | Axis to crop along |
-| `time` | ndarray or None | `None` | Time vector matching length along `axis` |
-
-### Returns
-
-- `cropped_windows`: shape `(n_windows, *original_shape_with_window_axis)`
-- If `time` is given: `(cropped_windows, cropped_times)` where `cropped_times.shape = (n_windows, window_length)`
-
-### Example
-
-```python
-import scitex as stx
-import numpy as np
-
-FS = 128
-sig2d = np.random.rand(19, FS * 13) # 19 channels, 13 seconds
-time = np.arange(sig2d.shape[-1]) / FS
-
-window_pts = FS * 2 # 2-second windows
-
-# 50% overlapping crop
-xx, tt = stx.dsp.crop(sig2d, window_pts, overlap_factor=0.5, time=time)
-# xx.shape: (n_windows, 19, 256)
-# tt.shape: (n_windows, 256)
-
-# No time output
-xx = stx.dsp.crop(sig2d, window_pts, overlap_factor=0.5)
-```
-
-## stx.dsp.to_sktime_df
-
-Convert a 3D array to sktime-compatible DataFrame format (nested DataFrames).
-
-```python
-df = stx.dsp.to_sktime_df(arr)
-```
-
-### Parameters
-
-| Parameter | Type | Description |
-|-----------|------|-------------|
-| `arr` | ndarray | Shape `(n_samples, seq_len, n_channels)` — note axis order |
-
-### Returns
-
-pandas `DataFrame` with one column (`dim_0`), where each cell contains a `pd.Series` of all channels for that sample.
-
-### Example
-
-```python
-import scitex as stx
-import numpy as np
-
-arr = np.random.randn(100, 256, 4) # 100 samples, 256 timepoints, 4 channels
-sktime_df = stx.dsp.to_sktime_df(arr)
-# sktime_df.shape: (100, 1)
-# sktime_df.iloc[0, 0]: Series with channel_0, channel_1, channel_2, channel_3
-```
-
-## Comparison: `crop` vs `to_segments`
-
-| | `crop` | `to_segments` |
-|-|--------|---------------|
-| Backend | NumPy | PyTorch |
-| Time vector | Supported | Not supported |
-| Overlap spec | Float fraction (0.0–1.0) | Integer factor |
-| Input dims | Any | `dim` parameter |
-| Decorator | None | `@torch_fn` |
diff --git a/src/scitex/dsp/_skills/spectral.md b/src/scitex/dsp/_skills/spectral.md
deleted file mode 100644
index 4e2a6f6cf..000000000
--- a/src/scitex/dsp/_skills/spectral.md
+++ /dev/null
@@ -1,94 +0,0 @@
----
-description: Power spectral density and per-band average power for multi-channel signals.
----
-
-# stx.dsp — Spectral Analysis
-
-Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_psd.py`
-
-## psd
-
-```python
-psd_vals, freqs = stx.dsp.psd(x, fs, prob=False, dim=-1)
-```
-
-Computes the power spectral density using the PyTorch `PSD` module from `scitex.nn._PSD`.
-
-### Parameters
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `x` | ndarray / Tensor | required | Input signal, shape `(batch, chs, time)` |
-| `fs` | float | required | Sampling frequency in Hz |
-| `prob` | bool | `False` | If `True`, normalize PSD to sum to 1 (probability distribution) |
-| `dim` | int | `-1` | Time dimension |
-
-### Returns
-
-- `psd_vals`: power spectrum, shape `(batch, chs, n_freqs)`, log-scaled
-- `freqs`: frequency axis array, shape `(n_freqs,)`, in Hz
-
-Requires `torch`.
-
-## band_powers
-
-```python
-avg_powers = stx.dsp.band_powers(self, psd)
-```
-
-Computes average power within specified frequency bands from an existing PSD.
-
-Note: `band_powers` as exposed in `__init__.py` is a lower-level function that requires `self` (a PSD instance) and a pre-computed `psd` tensor. It is typically used internally or via the `PSD` class directly.
-
-## Built-in Frequency Bands
-
-Predefined bands are available in `stx.dsp.params.BANDS`:
-
-```python
-import scitex as stx
-
-print(stx.dsp.params.BANDS)
-# delta theta lalpha halpha beta gamma
-# low_hz 0.5 4.0 8.0 10.0 13.0 32.0
-# high_hz 4.0 8.0 10.0 13.0 32.0 75.0
-```
-
-## Examples
-
-```python
-import scitex as stx
-
-xx, tt, fs = stx.dsp.demo_sig(sig_type="chirp", fs=512, t_sec=4)
-
-# Compute PSD
-psd_vals, freqs = stx.dsp.psd(xx, fs)
-# psd_vals.shape: (8, 19, n_freqs)
-# freqs.shape: (n_freqs,)
-
-# Normalized PSD (sums to 1 per channel)
-psd_prob, freqs = stx.dsp.psd(xx, fs, prob=True)
-
-# Plot PSD for first batch/channel
-import matplotlib.pyplot as plt
-fig, ax = plt.subplots()
-ax.plot(freqs, psd_vals[0, 0])
-ax.set_xlabel("Frequency [Hz]")
-ax.set_ylabel("log(Power [uV^2 / Hz])")
-```
-
-## Full pipeline example
-
-```python
-import scitex as stx
-import numpy as np
-
-xx, tt, fs = stx.dsp.demo_sig(fs=512, t_sec=2)
-
-# Filter to gamma band first, then compute PSD
-gamma = stx.dsp.filt.bandpass(xx, fs, np.array([[32, 75]]))
-psd_gamma, freqs = stx.dsp.psd(gamma, fs)
-
-# Check which frequency has peak power in first batch/channel
-peak_idx = psd_gamma[0, 0].argmax()
-print(f"Peak frequency: {freqs[peak_idx]:.1f} Hz")
-```
diff --git a/src/scitex/dsp/_skills/utils.md b/src/scitex/dsp/_skills/utils.md
deleted file mode 100644
index 184329525..000000000
--- a/src/scitex/dsp/_skills/utils.md
+++ /dev/null
@@ -1,175 +0,0 @@
----
-description: Internal helpers — zero-padding, FIR filter design, differentiable bandpass filter banks.
----
-
-# stx.dsp.utils — Utilities
-
-Source directory: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/utils/`
-
-These utilities are used internally by the higher-level DSP functions but can also be called directly.
-
-## Zero-padding
-
-Source: `utils/_zero_pad.py`
-
-```python
-from scitex.dsp.utils import zero_pad, _zero_pad_1d
-```
-
-### `_zero_pad_1d(x, target_length)`
-
-Zero-pad a 1D tensor to a target length, padding symmetrically.
-
-```python
-from scitex.dsp.utils import _zero_pad_1d
-import torch
-
-x = torch.tensor([1.0, 2.0, 3.0])
-padded = _zero_pad_1d(x, target_length=7)
-# tensor([0., 1., 2., 3., 0., 0., 0.]) — 2 left, 2 right
-```
-
-### `zero_pad(xs, dim=0)`
-
-Zero-pad a list of variable-length tensors/arrays to the same length and stack them.
-
-```python
-from scitex.dsp.utils import zero_pad
-import torch
-
-xs = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
-stacked = zero_pad(xs, dim=0)
-# tensor([[1., 2., 0.],
-# [3., 4., 5.]])
-```
-
-Accepts NumPy arrays, converts them to tensors automatically.
-
-## FIR filter design
-
-Source: `utils/filter.py`
-
-```python
-from scitex.dsp.utils.filter import design_filter, plot_filter_responses
-```
-
-### `design_filter(sig_len, fs, low_hz=None, high_hz=None, cycle=3, is_bandstop=False)`
-
-Design an FIR filter using `scipy.signal.firwin` with Hamming window. Returns filter coefficient array.
-
-Decorated with `@numpy_fn` (converts tensors to arrays automatically).
-
-| Parameter | Description |
-|-----------|-------------|
-| `sig_len` | Signal length (determines maximum filter order) |
-| `fs` | Sampling frequency |
-| `low_hz` | Low cutoff (omit for highpass) |
-| `high_hz` | High cutoff (omit for lowpass) |
-| `cycle` | Number of cycles at lowest frequency; determines filter order |
-| `is_bandstop` | `True` for bandstop when both `low_hz` and `high_hz` given |
-
-Filter type selection:
-- `low_hz` only → lowpass
-- `high_hz` only → highpass
-- both → bandpass (default) or bandstop (`is_bandstop=True`)
-
-```python
-from scitex.dsp.utils.filter import design_filter, plot_filter_responses
-import scitex as stx
-
-xx, tt, fs = stx.dsp.demo_sig()
-seq_len = xx.shape[-1]
-
-bp = design_filter(seq_len, fs, low_hz=30, high_hz=70)
-lp = design_filter(seq_len, fs, low_hz=30)
-hp = design_filter(seq_len, fs, high_hz=70)
-bs = design_filter(seq_len, fs, low_hz=30, high_hz=70, is_bandstop=True)
-```
-
-### `plot_filter_responses(filter, fs, worN=8000, title=None)`
-
-Plot impulse response and frequency response of an FIR filter. Returns a matplotlib `Figure`.
-
-```python
-fig = plot_filter_responses(bp, fs, title="Bandpass 30-70 Hz")
-```
-
-## Differentiable bandpass filter banks (for gradient-based optimization)
-
-Source: `utils/_differential_bandpass_filters.py`
-
-```python
-from scitex.dsp.utils import init_bandpass_filters, build_bandpass_filters
-```
-
-These build learnable PAC filter banks whose center frequencies can be optimized via backpropagation.
-
-Requires `torchaudio.prototype.functional.sinc_impulse_response`.
-
-### `init_bandpass_filters(sig_len, fs, pha_low_hz, pha_high_hz, pha_n_bands, amp_low_hz, amp_high_hz, amp_n_bands, cycle)`
-
-Initialize a filter bank with learnable `pha_mids` and `amp_mids` parameters.
-
-```python
-from scitex.dsp.utils import init_bandpass_filters
-import scitex as stx
-
-xx, tt, fs = stx.dsp.demo_sig(fs=1024)
-filters, pha_mids, amp_mids = init_bandpass_filters(
- sig_len=xx.shape[-1],
- fs=fs,
- pha_low_hz=2, pha_high_hz=20, pha_n_bands=30,
- amp_low_hz=60, amp_high_hz=160, amp_n_bands=50,
-)
-# filters: stacked impulse responses shape (pha_n_bands + amp_n_bands, filter_len)
-# pha_mids, amp_mids: nn.Parameter — gradients flow through these
-
-# Verify gradients work
-filters.sum().backward()
-print(pha_mids.grad) # not None
-```
-
-### `build_bandpass_filters(sig_len, fs, pha_mids, amp_mids, cycle)`
-
-Rebuild filter bank from (updated) center frequency parameters. Call this in the forward pass after optimizer.step() to apply learned frequencies.
-
-```python
-from scitex.dsp.utils import build_bandpass_filters
-
-# After optimizer step:
-new_filters = build_bandpass_filters(sig_len, fs, pha_mids, amp_mids, cycle=3)
-```
-
-## ensure_3d
-
-Source: `utils/_ensure_3d.py` (also available as `stx.dsp.ensure_3d`)
-
-```python
-x_3d = stx.dsp.ensure_3d(x)
-```
-
-Promotes 1D `(time,)` or 2D `(batch, time)` tensors to 3D `(batch, chs, time)` for compatibility with all DSP functions.
-
-```python
-import torch
-import scitex as stx
-
-x1d = torch.randn(512)
-x2d = torch.randn(8, 512)
-x3d = torch.randn(8, 19, 512)
-
-stx.dsp.ensure_3d(x1d).shape # (1, 1, 512)
-stx.dsp.ensure_3d(x2d).shape # (8, 1, 512)
-stx.dsp.ensure_3d(x3d).shape # (8, 19, 512) — unchanged
-```
-
-## stx.dsp.time
-
-Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_time.py`
-
-Generate a time vector using `stx.gen.float_linspace`.
-
-```python
-t = stx.dsp.time(start_sec=0, end_sec=5, fs=256)
-# Returns array of (end_sec - start_sec) * fs evenly spaced values
-```
diff --git a/src/scitex/dsp/_skills/wavelet.md b/src/scitex/dsp/_skills/wavelet.md
deleted file mode 100644
index 45fb4e51a..000000000
--- a/src/scitex/dsp/_skills/wavelet.md
+++ /dev/null
@@ -1,119 +0,0 @@
----
-description: Continuous wavelet transform returning time-frequency phase and amplitude.
----
-
-# stx.dsp.wavelet — Wavelet Transform
-
-Source: `/home/ywatanabe/proj/scitex-python/src/scitex/dsp/_wavelet.py`
-
-## Signature
-
-```python
-pha, amp, freqs = stx.dsp.wavelet(
- x,
- fs,
- freq_scale="linear",
- out_scale="linear",
- device="cuda",
- batch_size=32,
-)
-```
-
-Computes a continuous wavelet transform (CWT) using the `Wavelet` module from `scitex.nn._Wavelet`. The function is decorated with both `@signal_fn` and `@batch_fn`, so it handles type conversion and automatic batch splitting when input is too large for GPU memory.
-
-Requires `torch`.
-
-### Parameters
-
-| Parameter | Type | Default | Description |
-|-----------|------|---------|-------------|
-| `x` | ndarray / Tensor | required | Signal, shape `(batch, chs, time)` |
-| `fs` | float | required | Sampling frequency in Hz |
-| `freq_scale` | str | `"linear"` | Frequency axis spacing: `"linear"` or `"log"` |
-| `out_scale` | str | `"linear"` | Output amplitude scale: `"linear"` or `"log"` |
-| `device` | str | `"cuda"` | PyTorch device |
-| `batch_size` | int | `32` | Batch size for the `@batch_fn` wrapper |
-
-### Returns
-
-- `pha`: instantaneous phase, shape `(batch, chs, n_freqs, time)`
-- `amp`: amplitude envelope, shape `(batch, chs, n_freqs, time)`
-- `freqs`: frequency axis, shape `(batch, chs, n_freqs)` — take `freqs[0, 0]` for the 1D array
-
-## Examples
-
-```python
-import scitex as stx
-import numpy as np
-
-xx, tt, fs = stx.dsp.demo_sig(
- sig_type="chirp", batch_size=4, n_chs=2, fs=512, t_sec=4
-)
-
-# Compute wavelet transform
-pha, amp, freqs = stx.dsp.wavelet(xx, fs, device="cuda")
-
-# freqs is per-batch/channel; take the 1D version
-freqs_1d = freqs[0, 0] # shape: (n_freqs,)
-
-print(f"pha shape: {pha.shape}") # (4, 2, n_freqs, 2048)
-print(f"amp shape: {amp.shape}") # (4, 2, n_freqs, 2048)
-print(f"freqs: {freqs_1d}")
-```
-
-### Log-scale amplitude output
-
-```python
-pha, amp_log, freqs = stx.dsp.wavelet(xx, fs, out_scale="log")
-# amp_log contains log(amplitude + 1e-5) — NaN-safe log scaling
-```
-
-### Spectrogram plot
-
-```python
-import matplotlib.pyplot as plt
-
-pha, amp, freqs = stx.dsp.wavelet(xx, fs)
-freqs_1d = freqs[0, 0].cpu().numpy()
-i_batch, i_ch = 0, 0
-
-fig, axes = plt.subplots(3, 1, figsize=(10, 8))
-
-# Raw signal
-axes[0].plot(tt, xx[i_batch, i_ch])
-axes[0].set_ylabel("Amplitude")
-axes[0].set_title("Signal")
-
-# Amplitude spectrogram
-log_amp = (amp[i_batch, i_ch] + 1e-5).log().cpu().numpy()
-axes[1].imshow(log_amp.T, aspect="auto", origin="lower")
-axes[1].set_ylabel("Frequency [Hz]")
-axes[1].set_title("Wavelet Amplitude")
-
-# Phase spectrogram
-phase_np = pha[i_batch, i_ch].cpu().numpy()
-axes[2].imshow(phase_np.T, aspect="auto", origin="lower")
-axes[2].set_ylabel("Frequency [Hz]")
-axes[2].set_title("Wavelet Phase [rad]")
-axes[2].set_xlabel("Time [s]")
-```
-
-### Log frequency scale
-
-```python
-pha, amp, freqs = stx.dsp.wavelet(xx, fs, freq_scale="log")
-# freqs are logarithmically spaced — more resolution at low frequencies
-```
-
-## PAC segments
-
-When input has a segment dimension `(batch, chs, n_segments, time)`, extract one segment first:
-
-```python
-xx, tt, fs = stx.dsp.demo_sig(sig_type="pac", n_segments=20, fs=512, t_sec=4)
-# xx.shape: (8, 19, 20, 2048) — has segment dim
-
-i_segment = 0
-xx_seg = xx[:, :, i_segment, :] # (8, 19, 2048)
-pha, amp, freqs = stx.dsp.wavelet(xx_seg, fs)
-```
diff --git a/src/scitex/dsp/_time.py b/src/scitex/dsp/_time.py
deleted file mode 100755
index 659d143bf..000000000
--- a/src/scitex/dsp/_time.py
+++ /dev/null
@@ -1,40 +0,0 @@
-#!./env/bin/python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-06-30 12:11:01 (ywatanabe)"
-# /mnt/ssd/ripple-wm-code/scripts/externals/scitex/src/scitex/dsp/_time.py
-
-
-import numpy as np
-
-import scitex
-
-
-def time(start_sec, end_sec, fs):
- # return np.linspace(start_sec, end_sec, (end_sec - start_sec) * fs)
- return scitex.gen.float_linspace(start_sec, end_sec, (end_sec - start_sec) * fs)
-
-
-def main():
- out = time(10, 15, 256)
- print(out)
-
-
-if __name__ == "__main__":
- import sys
-
- import matplotlib.pyplot as plt
-
- # # Argument Parser
- # import argparse
- # parser = argparse.ArgumentParser(description='')
- # parser.add_argument('--var', '-v', type=int, default=1, help='')
- # parser.add_argument('--flag', '-f', action='store_true', default=False, help='')
- # args = parser.parse_args()
- # Main
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(
- sys, plt, verbose=False
- )
- main()
- scitex.session.close(CONFIG, verbose=False, notify=False)
-
-# EOF
diff --git a/src/scitex/dsp/_transform.py b/src/scitex/dsp/_transform.py
deleted file mode 100755
index caeda6aca..000000000
--- a/src/scitex/dsp/_transform.py
+++ /dev/null
@@ -1,80 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2024-04-08 12:41:59 (ywatanabe)"#!/usr/bin/env python3
-
-
-import numpy as np
-import pandas as pd
-
-try:
- import torch
-
- TORCH_AVAILABLE = True
-except ImportError:
- TORCH_AVAILABLE = False
- torch = None
-
-from scitex.decorators import torch_fn
-
-
-def _check_torch():
- if not TORCH_AVAILABLE:
- raise ImportError(
- "PyTorch is not installed. Please install with: pip install torch"
- )
-
-
-def to_sktime_df(arr):
- """
- Convert a 3D numpy array into a DataFrame suitable for sktime.
-
- Parameters:
- arr (numpy.ndarray): A 3D numpy array with shape (n_samples, n_channels, seq_len)
-
- Returns:
- pandas.DataFrame: A DataFrame in sktime format
- """
- if len(arr.shape) != 3:
- raise ValueError("Input data must be a 3D array")
-
- n_samples, seq_len, n_channels = arr.shape
-
- # Initialize an empty DataFrame for sktime format
- sktime_df = pd.DataFrame(index=range(n_samples), columns=["dim_0"])
-
- # Iterate over each sample
- for i in range(n_samples):
- # Combine all channels into a single cell
- combined_series = pd.Series(
- {f"channel_{j}": pd.Series(arr[i, :, j]) for j in range(n_channels)}
- )
- sktime_df.iloc[i, 0] = combined_series
-
- return sktime_df
-
-
-@torch_fn
-def to_segments(x, window_size, overlap_factor=1, dim=-1):
- stride = window_size // overlap_factor
- num_windows = (x.size(dim) - window_size) // stride + 1
- windows = x.unfold(dim, window_size, stride)
- return windows
-
-
-if __name__ == "__main__":
- import scitex
-
- x, t, f = scitex.dsp.demo_sig()
-
- y = to_segments(x, 256)
-
- x = 100 * np.random.rand(16, 160, 1000)
- print(_normalize_time(x))
-
- x = torch.randn(16, 160, 1000)
- print(_normalize_time(x))
-
- x = torch.randn(16, 160, 1000).cuda()
- print(_normalize_time(x))
-
-
-# EOF
diff --git a/src/scitex/dsp/_wavelet.py b/src/scitex/dsp/_wavelet.py
deleted file mode 100755
index 521111fe4..000000000
--- a/src/scitex/dsp/_wavelet.py
+++ /dev/null
@@ -1,212 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-11-04 02:12:00 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/_wavelet.py
-
-"""scitex.dsp.wavelet function"""
-
-import scitex
-from scitex.decorators import batch_fn, signal_fn
-from scitex.nn._Wavelet import Wavelet
-
-
-# Functions
-@signal_fn
-@batch_fn
-def wavelet(
- x,
- fs,
- freq_scale="linear",
- out_scale="linear",
- device="cuda",
- batch_size=32,
-):
- m = Wavelet(fs, freq_scale=freq_scale, out_scale="linear").to(device).eval()
- pha, amp, freqs = m(x.to(device))
-
- if out_scale == "log":
- amp = (amp + 1e-5).log()
- if amp.isnan().any():
- print("NaN is detected while taking the lograrithm of amplitude.")
-
- return pha, amp, freqs
-
-
-# @signal_fn
-# def wavelet(
-# x,
-# fs,
-# freq_scale="linear",
-# out_scale="linear",
-# device="cuda",
-# batch_size=32,
-# ):
-# @signal_fn
-# def _wavelet(
-# x,
-# fs,
-# freq_scale="linear",
-# out_scale="linear",
-# device="cuda",
-# ):
-# m = (
-# Wavelet(fs, freq_scale=freq_scale, out_scale=out_scale)
-# .to(device)
-# .eval()
-# )
-# pha, amp, freqs = m(x.to(device))
-
-# if out_scale == "log":
-# amp = (amp + 1e-5).log()
-# if amp.isnan().any():
-# print(
-# "NaN is detected while taking the lograrithm of amplitude."
-# )
-
-# return pha, amp, freqs
-
-# if len(x) <= batch_size:
-# try:
-# pha, amp, freqs = _wavelet(
-# x,
-# fs,
-# freq_scale=freq_scale,
-# out_scale=out_scale,
-# device=device,
-# )
-# torch.cuda.empty_cache()
-# return pha, amp, freqs
-
-# except Exception as e:
-# print(e)
-# print("\nTrying Batch Mode...")
-
-# n_batches = (len(x) + batch_size - 1) // batch_size
-# device_orig = x.device
-# pha, amp, freqs = [], [], []
-# for i_batch in tqdm(range(n_batches)):
-# start = i_batch * batch_size
-# end = (i_batch + 1) * batch_size
-# _pha, _amp, _freqs = _wavelet(
-# x[start:end],
-# fs,
-# freq_scale=freq_scale,
-# out_scale=out_scale,
-# device=device,
-# )
-# torch.cuda.empty_cache()
-# # to CPU
-# pha.append(_pha.cpu())
-# amp.append(_amp.cpu())
-# freqs.append(_freqs.cpu())
-
-# pha = torch.vstack(pha)
-# amp = torch.vstack(amp)
-# freqs = freqs[0]
-
-# try:
-# pha = pha.to(device_orig)
-# amp = amp.to(device_orig)
-# freqs = freqs.to(device_orig)
-# except Exception as e:
-# print(
-# f"\nError occurred while transferring wavelet outputs back to the original device. Proceeding with CPU tensor. \n\n({e})"
-# )
-
-# sleep(0.5)
-# torch.cuda.empty_cache()
-# return pha, amp, freqs
-
-
-if __name__ == "__main__":
- import sys
-
- import matplotlib.pyplot as plt
- import numpy as np
-
- # Start
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt, agg=True)
-
- # Parameters
- FS = 512
- SIG_TYPE = "chirp"
- T_SEC = 4
-
- # Demo signal
- xx, tt, fs = scitex.dsp.demo_sig(
- batch_size=64,
- n_chs=19,
- n_segments=2,
- t_sec=T_SEC,
- fs=FS,
- sig_type=SIG_TYPE,
- )
-
- if SIG_TYPE in ["tensorpac", "pac"]:
- i_segment = 0
- xx = xx[:, :, i_segment, :]
-
- # Main
- pha, amp, freqs = wavelet(xx, fs, device="cuda")
- freqs = freqs[0, 0]
-
- # Plots
- i_batch, i_ch = 0, 0
- fig, axes = scitex.plt.subplots(nrows=3)
-
- # # Time vector for x-axis extents
- # time_extent = [tt.min(), tt.max()]
-
- # Trace
- axes[0].plot(tt, xx[i_batch, i_ch], label=SIG_TYPE)
- axes[0].set_ylabel("Amplitude [?V]")
- axes[0].legend(loc="upper left")
- axes[0].set_title("Signal")
-
- # Amplitude
- # extent = [time_extent[0], time_extent[1], freqs.min(), freqs.max()]
- axes[1].imshow2d(
- np.log(amp[i_batch, i_ch] + 1e-5).T,
- cbar_label="Log(amplitude [?V]) [a.u.]",
- aspect="auto",
- # extent=extent,
- # origin="lower",
- )
- axes[1] = scitex.plt.ax.set_ticks(axes[1], x_ticks=tt, y_ticks=freqs)
- axes[1].set_ylabel("Frequency [Hz]")
- axes[1].set_title("Amplitude")
-
- # Phase
- axes[2].imshow2d(
- pha[i_batch, i_ch].T,
- cbar_label="Phase [rad]",
- aspect="auto",
- # extent=extent,
- # origin="lower",
- )
- axes[2] = scitex.plt.ax.set_ticks(axes[2], x_ticks=tt, y_ticks=freqs)
- axes[2].set_ylabel("Frequency [Hz]")
- axes[2].set_title("Phase")
-
- fig.suptitle("Wavelet Transformation")
- fig.supxlabel("Time [s]")
-
- for ax in axes:
- ax = scitex.plt.ax.set_n_ticks(ax)
- # ax.set_xlim(time_extent[0], time_extent[1])
-
- fig.tight_layout(rect=[0, 0.03, 1, 0.95])
-
- scitex.io.save(fig, "wavelet.png")
-
- # Close
- scitex.session.close(CONFIG)
-
-# EOF
-
-"""
-/home/ywatanabe/proj/entrance/scitex/dsp/_wavelet.py
-"""
-
-
-# EOF
diff --git a/src/scitex/dsp/add_noise.py b/src/scitex/dsp/add_noise.py
deleted file mode 100755
index 5de59e9e4..000000000
--- a/src/scitex/dsp/add_noise.py
+++ /dev/null
@@ -1,128 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "ywatanabe (2024-11-02 23:09:49)"
-# File: ./scitex_repo/src/scitex/dsp/add_noise.py
-
-try:
- import torch
-
- TORCH_AVAILABLE = True
-except ImportError:
- TORCH_AVAILABLE = False
- torch = None
-
-from scitex.decorators import signal_fn
-
-
-def _check_torch():
- if not TORCH_AVAILABLE:
- raise ImportError(
- "PyTorch is not installed. Please install with: pip install torch"
- )
-
-
-def _uniform(shape, amp=1.0):
- _check_torch()
- a, b = -amp, amp
- return -amp + (2 * amp) * torch.rand(shape)
-
-
-@signal_fn
-def gauss(x, amp=1.0):
- noise = amp * torch.randn(x.shape)
- return x + noise.to(x.device)
-
-
-@signal_fn
-def white(x, amp=1.0):
- return x + _uniform(x.shape, amp=amp).to(x.device)
-
-
-@signal_fn
-def pink(x, amp=1.0, dim=-1):
- """
- Adds pink noise to a given tensor along a specified dimension.
-
- Parameters:
- - x (torch.Tensor): The input tensor to which pink noise will be added.
- - amp (float, optional): The amplitude of the pink noise. Defaults to 1.0.
- - dim (int, optional): The dimension along which to add pink noise. Defaults to -1.
-
- Returns:
- - torch.Tensor: The input tensor with added pink noise.
- """
- cols = x.size(dim)
- noise = torch.randn(cols, dtype=x.dtype, device=x.device)
- noise = torch.fft.rfft(noise)
- indices = torch.arange(1, noise.size(0), dtype=x.dtype, device=x.device)
- noise[1:] /= torch.sqrt(indices)
- noise = torch.fft.irfft(noise, n=cols)
- noise = noise - noise.mean()
- noise_amp = torch.sqrt(torch.mean(noise**2))
- noise = noise * (amp / noise_amp)
- return x + noise.to(x.device)
-
-
-@signal_fn
-def brown(x, amp=1.0, dim=-1):
- from scitex.dsp import norm
-
- noise = _uniform(x.shape, amp=amp)
- noise = torch.cumsum(noise, dim=dim)
- noise = norm.minmax(noise, amp=amp, dim=dim)
- return x + noise.to(x.device)
-
-
-if __name__ == "__main__":
- import sys
-
- import matplotlib.pyplot as plt
-
- import scitex
-
- # Start
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt)
-
- # Parameters
- T_SEC = 1
- FS = 128
-
- # Demo signal
- xx, tt, fs = scitex.dsp.demo_sig(t_sec=T_SEC, fs=FS)
-
- funcs = {
- "orig": lambda x: x,
- "gauss": gauss,
- "white": white,
- "pink": pink,
- "brown": brown,
- }
-
- # Plots
- fig, axes = scitex.plt.subplots(nrows=len(funcs), ncols=2, sharex=True, sharey=True)
- count = 0
- for (k, fn), axes_row in zip(funcs.items(), axes):
- for ax in axes_row:
- if count % 2 == 0:
- ax.plot(tt, fn(xx)[0, 0], label=k, c="blue")
- else:
- ax.plot(tt, (fn(xx) - xx)[0, 0], label=f"{k} - orig", c="red")
- count += 1
- ax.legend(loc="upper right")
-
- fig.supxlabel("Time [s]")
- fig.supylabel("Amplitude [?V]")
- axes[0, 0].set_title("Signal + Noise")
- axes[0, 1].set_title("Noise")
-
- scitex.io.save(fig, "traces.png")
-
- # Close
- scitex.session.close(CONFIG)
-
-# EOF
-
-"""
-/home/ywatanabe/proj/entrance/scitex/dsp/add_noise.py
-"""
-
-# EOF
diff --git a/src/scitex/dsp/audio.md b/src/scitex/dsp/audio.md
deleted file mode 100755
index 9ea6f8214..000000000
--- a/src/scitex/dsp/audio.md
+++ /dev/null
@@ -1,22 +0,0 @@
-
-
-
-## Audio Support
-``` bash
-sudo apt remove python3-pyaudio
-sudo apt-get install -y libasound2-dev portaudio19-dev libportaudio2
-pip install --no-cache-dir pyaudio
-
-sudo apt-get install -y alsa-utils
-speaker-test -t sine -f 440
-
-sudo apt-get update
-sudo apt-get install -y pulseaudio
-sudo usermod -aG audio $USER
-pulseaudio --start
-
-```
diff --git a/src/scitex/dsp/example.py b/src/scitex/dsp/example.py
deleted file mode 100755
index f61ad6df9..000000000
--- a/src/scitex/dsp/example.py
+++ /dev/null
@@ -1,261 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2024-04-06 01:36:18 (ywatanabe)"
-
-import matplotlib
-
-matplotlib.use("Agg")
-import matplotlib.pyplot as plt
-import pandas as pd
-
-import scitex
-
-# Module-level constants (defaults for example functions)
-TGT_FS = 512
-LOW_HZ = 20
-HIGH_HZ = 50
-SIGMA = 10
-
-# Default color cycle
-CC = {"blue": "#1f77b4", "red": "#d62728", "green": "#2ca02c"}
-
-
-# Functions
-def calc_norm_resample_filt_hilbert(xx, tt, fs, sig_type, verbose=True):
- sigs = {"index": ("signal", "time", "fs")} # Collector
-
- if sig_type == "tensorpac":
- xx = xx[:, :, 0]
-
- sigs["orig"] = (xx, tt, fs)
-
- # Normalization
- sigs["z_normed"] = (scitex.dsp.norm.z(xx), tt, fs)
- sigs["minmax_normed"] = (scitex.dsp.norm.minmax(xx), tt, fs)
-
- # Resampling
- resampled_xx = scitex.dsp.resample(xx, fs, TGT_FS)
- # Create proper time vector for resampled signal
- import numpy as np
-
- resampled_tt = np.linspace(tt[0], tt[-1], resampled_xx.shape[-1])
- sigs["resampled"] = (resampled_xx, resampled_tt, TGT_FS)
-
- # Noise injection
- sigs["gaussian_noise_added"] = (scitex.dsp.add_noise.gauss(xx), tt, fs)
- sigs["white_noise_added"] = (scitex.dsp.add_noise.white(xx), tt, fs)
- sigs["pink_noise_added"] = (scitex.dsp.add_noise.pink(xx), tt, fs)
- sigs["brown_noise_added"] = (scitex.dsp.add_noise.brown(xx), tt, fs)
-
- # Filtering (bands format is [[low_hz, high_hz]])
- bands = [[LOW_HZ, HIGH_HZ]]
- sigs[f"bandpass_filted ({LOW_HZ} - {HIGH_HZ} Hz)"] = (
- scitex.dsp.filt.bandpass(xx, fs, bands),
- tt,
- fs,
- )
-
- sigs[f"bandstop_filted ({LOW_HZ} - {HIGH_HZ} Hz)"] = (
- scitex.dsp.filt.bandstop(xx, fs, bands),
- tt,
- fs,
- )
- sigs[f"bandstop_gauss (sigma = {SIGMA})"] = (
- scitex.dsp.filt.gauss(xx, sigma=SIGMA),
- tt,
- fs,
- )
-
- # Hilbert Transformation
- pha, amp = scitex.dsp.hilbert(xx)
- sigs["hilbert_amp"] = (amp, tt, fs)
- sigs["hilbert_pha"] = (pha, tt, fs)
-
- sigs = pd.DataFrame(sigs).set_index("index")
-
- if verbose:
- print(sigs.index)
- print(sigs.columns)
-
- return sigs
-
-
-def plot_signals(plt, sigs, sig_type):
- fig, axes = plt.subplots(nrows=len(sigs.columns), sharex=True)
-
- i_batch = 0
- i_ch = 0
- for ax, (i_col, col) in zip(axes, enumerate(sigs.columns)):
- if col == "hilbert_amp": # add the original signal to the ax
- _col = "orig"
- (
- _xx,
- _tt,
- _fs,
- ) = sigs[_col]
- ax.plot(_tt, _xx[i_batch, i_ch], label=_col, c=CC["blue"])
-
- # Main
- xx, tt, fs = sigs[col]
- # if sig_type == "tensorpac":
- # xx = xx[:, :, 0]
-
- # Handle potential shape mismatches from filter operations
- signal = xx[i_batch, i_ch]
- if hasattr(signal, "squeeze"):
- signal = signal.squeeze()
- if hasattr(signal, "numpy"):
- signal = signal.numpy()
-
- ax.plot(
- tt,
- signal,
- label=col,
- c=CC["red"] if col == "hilbert_amp" else CC["blue"],
- )
-
- # Adjustments
- ax.legend(loc="upper left")
- ax.set_xlim(tt[0], tt[-1])
-
- ax = scitex.plt.ax.set_n_ticks(ax)
-
- fig.supxlabel("Time [s]")
- fig.supylabel("Voltage")
- fig.suptitle(sig_type)
- return fig
-
-
-def plot_wavelet(plt, sigs, sig_col, sig_type):
- xx, tt, fs = sigs[sig_col]
- # if sig_type == "tensorpac":
- # xx = xx[:, :, 0]
-
- # Wavelet Transformation
- wavelet_coef, ff_ww = scitex.dsp.wavelet(xx, fs)
-
- i_batch = 0
- i_ch = 0
-
- # Main
- fig, axes = plt.subplots(nrows=2, sharex=True)
- # Signal
- axes[0].plot(
- tt,
- xx[i_batch, i_ch],
- label=sig_col,
- c=CC["blue"],
- )
- # Adjusts
- axes[0].legend(loc="upper left")
- axes[0].set_xlim(tt[0], tt[-1])
- axes[0].set_ylabel("Voltage")
- axes[0] = scitex.plt.ax.set_n_ticks(axes[0])
-
- # Wavelet Spectrogram
- axes[1].imshow(
- wavelet_coef[i_batch, i_ch],
- aspect="auto",
- extent=[tt[0], tt[-1], 512, 1],
- label="wavelet_coefficient",
- )
- # axes[1].set_xlabel("Time [s]")
- axes[1].set_ylabel("Frequency [Hz]")
- # axes[1].legend(loc="upper left")
- axes[1].invert_yaxis()
-
- fig.supxlabel("Time [s]")
- fig.suptitle(sig_type)
-
- return fig
-
-
-def plot_psd(plt, sigs, sig_col, sig_type):
- xx, tt, fs = sigs[sig_col]
-
- # if sig_type == "tensorpac":
- # xx = xx[:, :, 0]
-
- # Power Spetrum Density
- psd, ff_pp = scitex.dsp.psd(xx, fs)
-
- # Main
- i_batch = 0
- i_ch = 0
- fig, axes = plt.subplots(nrows=2, sharex=False)
-
- # Signal
- axes[0].plot(
- tt,
- xx[i_batch, i_ch],
- label=sig_col,
- c=CC["blue"],
- )
- # Adjustments
- axes[0].legend(loc="upper left")
- axes[0].set_xlim(tt[0], tt[-1])
- axes[0].set_xlabel("Time [s]")
- axes[0].set_ylabel("Voltage")
- axes[0] = scitex.plt.ax.set_n_ticks(axes[0])
-
- # PSD
- axes[1].plot(ff_pp, psd[i_batch, i_ch], label="PSD")
- axes[1].set_yscale("log")
- axes[1].set_ylabel("Power [uV^2 / Hz]")
- axes[1].set_xlabel("Frequency [Hz]")
-
- fig.suptitle(sig_type)
-
- return fig
-
-
-if __name__ == "__main__":
- # Parameters
- T_SEC = 4
- SIG_TYPES = [
- # "uniform",
- # "gauss",
- # "periodic",
- # "chirp",
- # "ripple",
- # "meg",
- "tensorpac",
- ]
- SRC_FS = 1024
- TGT_FS = 512
- FREQS_HZ = [10, 30, 100]
- LOW_HZ = 20
- HIGH_HZ = 50
- SIGMA = 10
-
- plt, CC = scitex.plt.configure_mpl(plt, fig_scale=10)
- sdir = "/home/ywatanabe/proj/entrance/scitex/dsp/example/"
-
- for sig_type in SIG_TYPES:
- # Demo Signal
- xx, tt, fs = scitex.dsp.demo_sig(
- t_sec=T_SEC, fs=SRC_FS, freqs_hz=FREQS_HZ, sig_type=sig_type
- )
-
- # Apply calculations on the original signal
- sigs = calc_norm_resample_filt_hilbert(xx, tt, fs, sig_type)
-
- # Plots signals
- fig = plot_signals(plt, sigs, sig_type)
- scitex.io.save(fig, sdir + f"{sig_type}/1_signals.png")
-
- # Plots wavelet coefficients and PSD
- for sig_col in sigs.columns:
- if "hilbert" in sig_col:
- continue
-
- fig = plot_wavelet(plt, sigs, sig_col, sig_type)
- scitex.io.save(fig, sdir + f"{sig_type}/2_wavelet_{sig_col}.png")
-
- fig = plot_psd(plt, sigs, sig_col, sig_type)
- scitex.io.save(fig, sdir + f"{sig_type}/3_psd_{sig_col}.png")
-
- # plt.show()
-
- """
- python ./dsp/example.py
- """
diff --git a/src/scitex/dsp/filt.py b/src/scitex/dsp/filt.py
deleted file mode 100755
index 64289d093..000000000
--- a/src/scitex/dsp/filt.py
+++ /dev/null
@@ -1,164 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2024-11-04 02:05:47 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/filt.py
-
-import numpy as np
-
-import scitex
-from scitex.decorators import signal_fn
-
-# No top-level imports from nn module to avoid circular dependency
-# Filters will be imported inside functions when needed
-
-
-@signal_fn
-def gauss(x, sigma, t=None):
- from scitex.nn._Filters import GaussianFilter
-
- return GaussianFilter(sigma)(x, t=t)
-
-
-@signal_fn
-def bandpass(x, fs, bands, t=None):
- import torch
-
- from scitex.nn._Filters import BandPassFilter
-
- # Convert bands to tensor if it's not already
- if not isinstance(bands, torch.Tensor):
- bands = torch.tensor(bands, dtype=torch.float32)
- return BandPassFilter(bands, fs, x.shape[-1])(x, t=t)
-
-
-@signal_fn
-def bandstop(x, fs, bands, t=None):
- import torch
-
- from scitex.nn._Filters import BandStopFilter
-
- # Convert bands to tensor if it's not already
- if not isinstance(bands, torch.Tensor):
- bands = torch.tensor(bands, dtype=torch.float32)
- return BandStopFilter(bands, fs, x.shape[-1])(x, t=t)
-
-
-@signal_fn
-def lowpass(x, fs, cutoffs_hz, t=None):
- from scitex.nn._Filters import LowPassFilter
-
- return LowPassFilter(cutoffs_hz, fs, x.shape[-1])(x, t=t)
-
-
-@signal_fn
-def highpass(x, fs, cutoffs_hz, t=None):
- from scitex.nn._Filters import HighPassFilter
-
- return HighPassFilter(cutoffs_hz, fs, x.shape[-1])(x, t=t)
-
-
-def _custom_print(x):
- print(type(x), x.shape)
-
-
-if __name__ == "__main__":
- import sys
-
- import matplotlib.pyplot as plt
-
- # Start
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt)
-
- # Parametes
- T_SEC = 1
- SRC_FS = 1024
- FREQS_HZ = list(np.linspace(0, 500, 10, endpoint=False).astype(int))
- SIG_TYPE = "periodic"
- BANDS = np.vstack([[80, 310]])
- SIGMA = 3
-
- # Demo Signal
- xx, tt, fs = scitex.dsp.demo_sig(
- t_sec=T_SEC,
- fs=SRC_FS,
- freqs_hz=FREQS_HZ,
- sig_type=SIG_TYPE,
- )
-
- # Filtering
- x_bp, t_bp = scitex.dsp.filt.bandpass(xx, fs, BANDS, t=tt)
- x_bs, t_bs = scitex.dsp.filt.bandstop(xx, fs, BANDS, t=tt)
- x_lp, t_lp = scitex.dsp.filt.lowpass(xx, fs, BANDS[:, 0], t=tt)
- x_hp, t_hp = scitex.dsp.filt.highpass(xx, fs, BANDS[:, 1], t=tt)
- x_g, t_g = scitex.dsp.filt.gauss(xx, sigma=SIGMA, t=tt)
- filted = {
- f"Original (Sum of {FREQS_HZ}-Hz signals)": (xx, tt, fs),
- f"Bandpass-filtered ({BANDS[0][0]} - {BANDS[0][1]} Hz)": (
- x_bp,
- t_bp,
- fs,
- ),
- f"Bandstop-filtered ({BANDS[0][0]} - {BANDS[0][1]} Hz)": (
- x_bs,
- t_bs,
- fs,
- ),
- f"Lowpass-filtered ({BANDS[0][0]} Hz)": (x_lp, t_lp, fs),
- f"Highpass-filtered ({BANDS[0][1]} Hz)": (x_hp, t_hp, fs),
- f"Gaussian-filtered (sigma = {SIGMA} SD [point])": (x_g, t_g, fs),
- }
-
- # Plots traces
- fig, axes = plt.subplots(nrows=len(filted), ncols=1, sharex=True, sharey=True)
- i_batch = 0
- i_ch = 0
- i_filt = 0
- for ax, (k, v) in zip(axes, filted.items()):
- _xx, _tt, _fs = v
- if _xx.ndim == 3:
- _xx = _xx[i_batch, i_ch]
- elif _xx.ndim == 4:
- _xx = _xx[i_batch, i_ch, i_filt]
- ax.plot(_tt, _xx, label=k)
- ax.legend(loc="upper left")
-
- fig.suptitle("Filtered")
- fig.supxlabel("Time [s]")
- fig.supylabel("Amplitude")
-
- scitex.io.save(fig, "traces.png")
-
- # Calculates and Plots PSD
- fig, axes = plt.subplots(nrows=len(filted), ncols=1, sharex=True, sharey=True)
- i_batch = 0
- i_ch = 0
- i_filt = 0
- for ax, (k, v) in zip(axes, filted.items()):
- _xx, _tt, _fs = v
-
- _psd, ff = scitex.dsp.psd(_xx, _fs)
- if _psd.ndim == 3:
- _psd = _psd[i_batch, i_ch]
- elif _psd.ndim == 4:
- _psd = _psd[i_batch, i_ch, i_filt]
-
- ax.plot(ff, _psd, label=k)
- ax.legend(loc="upper left")
-
- for bb in np.hstack(BANDS):
- ax.axvline(x=bb, color=CC["grey"], linestyle="--")
-
- fig.suptitle("PSD (power spectrum density) of filtered signals")
- fig.supxlabel("Frequency [Hz]")
- fig.supylabel("log(Power [uV^2 / Hz]) [a.u.]")
- scitex.io.save(fig, "psd.png")
-
- # Close
- scitex.session.close(CONFIG)
-
-# EOF
-
-"""
-/home/ywatanabe/proj/scitex/src/scitex/dsp/filt.py
-"""
-
-# EOF
diff --git a/src/scitex/dsp/norm.py b/src/scitex/dsp/norm.py
deleted file mode 100755
index 676898c5d..000000000
--- a/src/scitex/dsp/norm.py
+++ /dev/null
@@ -1,33 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2024-04-05 12:15:42 (ywatanabe)"
-
-try:
- import torch as _torch
-
- TORCH_AVAILABLE = True
-except ImportError:
- TORCH_AVAILABLE = False
- _torch = None
-
-from scitex.decorators import signal_fn as _signal_fn
-
-
-def _check_torch():
- if not TORCH_AVAILABLE:
- raise ImportError(
- "PyTorch is not installed. Please install with: pip install torch"
- )
-
-
-@_signal_fn
-def z(x, dim=-1):
- _check_torch()
- return (x - x.mean(dim=dim, keepdim=True)) / x.std(dim=dim, keepdim=True)
-
-
-@_signal_fn
-def minmax(x, amp=1.0, dim=-1, fn="mean"):
- _check_torch()
- MM = x.max(dim=dim, keepdims=True)[0].abs()
- mm = x.min(dim=dim, keepdims=True)[0].abs()
- return amp * x / _torch.maximum(MM, mm)
diff --git a/src/scitex/dsp/params.py b/src/scitex/dsp/params.py
deleted file mode 100755
index 2b22e350e..000000000
--- a/src/scitex/dsp/params.py
+++ /dev/null
@@ -1,51 +0,0 @@
-import numpy as np
-import pandas as pd
-
-BANDS = pd.DataFrame(
- data=np.array([[0.5, 4], [4, 8], [8, 10], [10, 13], [13, 32], [32, 75]]).T,
- index=["low_hz", "high_hz"],
- columns=["delta", "theta", "lalpha", "halpha", "beta", "gamma"],
-)
-
-EEG_MONTAGE_1020 = [
- "FP1",
- "F3",
- "C3",
- "P3",
- "O1",
- "FP2",
- "F4",
- "C4",
- "P4",
- "O2",
- "F7",
- "T7",
- "P7",
- "F8",
- "T8",
- "P8",
- "FZ",
- "CZ",
- "PZ",
-]
-
-EEG_MONTAGE_BIPOLAR_TRANVERSE = [
- # Frontal
- "FP1-FP2",
- "F7-F3",
- "F3-FZ",
- "FZ-F4",
- "F4-F8",
- # Central
- "T7-C3",
- "C3-CZ",
- "CZ-C4",
- "C4-T8",
- # Parietal
- "P7-P3",
- "P3-PZ",
- "PZ-P4",
- "P4-P8",
- # Occipital
- "O1-O2",
-]
diff --git a/src/scitex/dsp/reference.py b/src/scitex/dsp/reference.py
deleted file mode 100755
index d1c1b1b06..000000000
--- a/src/scitex/dsp/reference.py
+++ /dev/null
@@ -1,60 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "ywatanabe (2024-11-02 22:48:44)"
-# File: ./scitex_repo/src/scitex/dsp/reference.py
-
-try:
- import torch as _torch
-
- TORCH_AVAILABLE = True
-except ImportError:
- TORCH_AVAILABLE = False
- _torch = None
-
-from scitex.decorators import torch_fn as _torch_fn
-
-
-def _check_torch():
- if not TORCH_AVAILABLE:
- raise ImportError(
- "PyTorch is not installed. Please install with: pip install torch"
- )
-
-
-@_torch_fn
-def common_average(x, dim=-2):
- _check_torch()
- re_referenced = (x - x.mean(dim=dim, keepdims=True)) / x.std(dim=dim, keepdims=True)
- assert x.shape == re_referenced.shape
- return re_referenced
-
-
-@_torch_fn
-def random(x, dim=-2):
- _check_torch()
- idx_all = [slice(None)] * x.ndim
- idx_rand_dim = _torch.randperm(x.shape[dim])
- idx_all[dim] = idx_rand_dim
- y = x[idx_all]
- re_referenced = x - y
- assert x.shape == re_referenced.shape
- return re_referenced
-
-
-@_torch_fn
-def take_reference(x, tgt_indi, dim=-2):
- _check_torch()
- idx_all = [slice(None)] * x.ndim
- idx_all[dim] = tgt_indi
- ref = x[tuple(idx_all)].unsqueeze(dim)
- re_referenced = x - ref
- assert x.shape == re_referenced.shape
- return re_referenced
-
-
-if __name__ == "__main__":
- import scitex
-
- x, f, t = scitex.dsp.demo_sig()
- y = common_average(x)
-
-# EOF
diff --git a/src/scitex/dsp/template.py b/src/scitex/dsp/template.py
deleted file mode 100755
index 08760333d..000000000
--- a/src/scitex/dsp/template.py
+++ /dev/null
@@ -1,26 +0,0 @@
-#!./env/bin/python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-04-10 15:57:54 (ywatanabe)"
-
-
-# FUnctions
-
-if __name__ == "__main__":
- import sys
-
- import matplotlib.pyplot as plt
- import torch
-
- import scitex
-
- # Start
- CONFIG, sys.stdout, sys.stderr, plt, cc = scitex.session.start(sys, plt)
-
- # Close
- scitex.session.close(CONFIG)
-
- """
- /home/ywatanabe/proj/scitex/src/scitex/dsp/template.py
- """
-
- # EOF
diff --git a/src/scitex/dsp/utils/__init__.py b/src/scitex/dsp/utils/__init__.py
deleted file mode 100755
index e6e20bdd7..000000000
--- a/src/scitex/dsp/utils/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/usr/bin/env python3
-"""Scitex utils module."""
-
-from ._differential_bandpass_filters import (
- build_bandpass_filters,
- init_bandpass_filters,
-)
-from ._ensure_3d import ensure_3d
-from ._ensure_even_len import ensure_even_len
-from ._zero_pad import _zero_pad_1d, zero_pad
-
-__all__ = [
- "_zero_pad_1d",
- "build_bandpass_filters",
- "ensure_3d",
- "ensure_even_len",
- "init_bandpass_filters",
- "zero_pad",
-]
diff --git a/src/scitex/dsp/utils/_differential_bandpass_filters.py b/src/scitex/dsp/utils/_differential_bandpass_filters.py
deleted file mode 100755
index 7b75ec7a7..000000000
--- a/src/scitex/dsp/utils/_differential_bandpass_filters.py
+++ /dev/null
@@ -1,148 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2024-11-26 22:24:13 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/utils/_differential_bandpass_filters.py
-
-THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/dsp/utils/_differential_bandpass_filters.py"
-
-import sys
-
-import matplotlib.pyplot as plt
-import numpy as np
-
-try:
- import torch
- import torch.nn as nn
-
- TORCH_AVAILABLE = True
-except ImportError:
- TORCH_AVAILABLE = False
- torch = None
- nn = None
-
-from scitex.decorators import torch_fn
-from scitex.gen._to_even import to_even
-from scitex.gen._to_odd import to_odd
-
-try:
- from torchaudio.prototype.functional import sinc_impulse_response
-
- TORCHAUDIO_AVAILABLE = True
-except ImportError:
- TORCHAUDIO_AVAILABLE = False
- sinc_impulse_response = None
-
-
-def _check_torch():
- if not TORCH_AVAILABLE:
- raise ImportError(
- "PyTorch is not installed. Please install with: pip install torch"
- )
-
-
-def _check_sinc_available():
- if sinc_impulse_response is None:
- raise ImportError(
- "sinc_impulse_response requires torchaudio.prototype.functional. "
- "Install torchaudio with: pip install torchaudio"
- )
-
-
-# Functions
-@torch_fn
-def init_bandpass_filters(
- sig_len,
- fs,
- pha_low_hz=2,
- pha_high_hz=20,
- pha_n_bands=30,
- amp_low_hz=60,
- amp_high_hz=160,
- amp_n_bands=50,
- cycle=3,
-):
- _check_sinc_available()
- # Learnable parameters
- pha_mids = nn.Parameter(torch.linspace(pha_low_hz, pha_high_hz, pha_n_bands))
- amp_mids = nn.Parameter(torch.linspace(amp_low_hz, amp_high_hz, amp_n_bands))
- filters = build_bandpass_filters(sig_len, fs, pha_mids, amp_mids, cycle)
- return filters, pha_mids, amp_mids
-
-
-@torch_fn
-def build_bandpass_filters(sig_len, fs, pha_mids, amp_mids, cycle):
- _check_sinc_available()
-
- def _define_freqs(mids, factor):
- lows = mids - mids / factor
- highs = mids + mids / factor
- return lows, highs
-
- def define_order(low_hz, fs, sig_len, cycle):
- order = cycle * int(fs // low_hz)
- order = order if 3 * order >= sig_len else (sig_len - 1) // 3
- order = to_even(order)
- return order
-
- def _calc_filters(lows_hz, highs_hz, fs, order):
- nyq = fs / 2.0
- order = to_odd(order)
- # lowpass filters
- irs_ll = sinc_impulse_response(lows_hz / nyq, window_size=order)
- irs_hh = sinc_impulse_response(highs_hz / nyq, window_size=order)
- irs = irs_ll - irs_hh
- return irs
-
- # Main
- pha_lows, pha_highs = _define_freqs(pha_mids, factor=4.0)
- amp_lows, amp_highs = _define_freqs(amp_mids, factor=8.0)
-
- lowest = min(pha_lows.min().item(), amp_lows.min().item())
- order = define_order(lowest, fs, sig_len, cycle)
-
- pha_bp_filters = _calc_filters(pha_lows, pha_highs, fs, order)
- amp_bp_filters = _calc_filters(amp_lows, amp_highs, fs, order)
- return torch.vstack([pha_bp_filters, amp_bp_filters])
-
-
-if __name__ == "__main__":
- import scitex
-
- # Start
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt, agg=True)
-
- # Demo signal
- freqs_hz = [10, 30, 100, 300]
- fs = 1024
- xx, tt, fs = scitex.dsp.demo_sig(fs=fs, freqs_hz=freqs_hz)
-
- # Main
- filters, pha_mids, amp_mids = init_bandpass_filters(xx.shape[-1], fs)
-
- filters.sum().backward() # OK. The filtering bands are trainable with backpropagation.
-
- # Update 'pha_mids' and 'amp_mids' in the forward method.
- # Then, re-build filters using optimized parameters like this:
- # self.filters = build_bandpass_filters(self.sig_len, self.fs, self.pha_mids, self.amp_mids, self.cycle)
-
- mids_all = np.concatenate(
- [pha_mids.detach().cpu().numpy(), amp_mids.detach().cpu().numpy()]
- )
-
- for i_filter in range(len(mids_all)):
- mid = mids_all[i_filter]
- fig = scitex.dsp.utils.filter.plot_filter_responses(
- filters[i_filter].detach().cpu().numpy(), fs, title=f"{mid:.1f} Hz"
- )
- scitex.io.save(
- fig,
- f"differentiable_bandpass_filter_reponses_filter#{i_filter:03d}_{mid:.1f}_Hz.png",
- )
- # plt.show()
-
-# EOF
-
-"""
-python -m scitex.dsp.utils._differential_bandpass_filters
-"""
-
-# EOF
diff --git a/src/scitex/dsp/utils/_ensure_3d.py b/src/scitex/dsp/utils/_ensure_3d.py
deleted file mode 100755
index 08f48786a..000000000
--- a/src/scitex/dsp/utils/_ensure_3d.py
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-11-05 01:04:03 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/utils/_ensure_3d.py
-
-from scitex.decorators import torch_fn
-
-
-@torch_fn
-def ensure_3d(x):
- if x.ndim == 1: # assumes (seq_len,)
- x = x.unsqueeze(0).unsqueeze(0)
- elif x.ndim == 2: # assumes (batch_siize, seq_len)
- x = x.unsqueeze(1)
- return x
-
-
-# EOF
diff --git a/src/scitex/dsp/utils/_ensure_even_len.py b/src/scitex/dsp/utils/_ensure_even_len.py
deleted file mode 100755
index 294e43927..000000000
--- a/src/scitex/dsp/utils/_ensure_even_len.py
+++ /dev/null
@@ -1,10 +0,0 @@
-#!./env/bin/python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-04-10 11:59:49 (ywatanabe)"
-
-
-def ensure_even_len(x):
- if x.shape[-1] % 2 == 0:
- return x
- else:
- return x[..., :-1]
diff --git a/src/scitex/dsp/utils/_zero_pad.py b/src/scitex/dsp/utils/_zero_pad.py
deleted file mode 100755
index a0e479744..000000000
--- a/src/scitex/dsp/utils/_zero_pad.py
+++ /dev/null
@@ -1,62 +0,0 @@
-#!/usr/bin/env python3
-# Time-stamp: "2024-11-26 10:30:34 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/utils/_zero_pad.py
-
-THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/dsp/utils/_zero_pad.py"
-
-import numpy as np
-
-try:
- import torch
- import torch.nn.functional as F
-
- TORCH_AVAILABLE = True
-except ImportError:
- TORCH_AVAILABLE = False
- torch = None
- F = None
-
-
-def _check_torch():
- if not TORCH_AVAILABLE:
- raise ImportError(
- "PyTorch is not installed. Please install with: pip install torch"
- )
-
-
-def _zero_pad_1d(x, target_length):
- """Zero pad a 1D tensor to target length."""
- _check_torch()
- if not isinstance(x, torch.Tensor):
- x = torch.tensor(x)
- padding_needed = target_length - len(x)
- padding_left = padding_needed // 2
- padding_right = padding_needed - padding_left
- return F.pad(x, (padding_left, padding_right), "constant", 0)
-
-
-def zero_pad(xs, dim=0):
- """Zero pad a list of arrays to the same length.
-
- Args:
- xs: List of tensors or arrays
- dim: Dimension to stack along
-
- Returns:
- Stacked tensor with zero padding
- """
- # Convert to tensors if needed
- tensors = []
- for x in xs:
- if isinstance(x, np.ndarray):
- tensors.append(torch.tensor(x))
- elif isinstance(x, torch.Tensor):
- tensors.append(x)
- else:
- tensors.append(torch.tensor(x))
-
- max_len = max([len(x) for x in tensors])
- return torch.stack([_zero_pad_1d(x, max_len) for x in tensors], dim=dim)
-
-
-# EOF
diff --git a/src/scitex/dsp/utils/filter.py b/src/scitex/dsp/utils/filter.py
deleted file mode 100755
index 572163915..000000000
--- a/src/scitex/dsp/utils/filter.py
+++ /dev/null
@@ -1,408 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-11-03 07:24:43 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/dsp/utils/filter.py
-
-import matplotlib.pyplot as plt
-import numpy as np
-from scipy.signal import firwin, freqz
-
-from scitex.decorators import numpy_fn
-from scitex.gen._to_even import to_even
-
-
-@numpy_fn
-def design_filter(sig_len, fs, low_hz=None, high_hz=None, cycle=3, is_bandstop=False):
- """
- Designs a Finite Impulse Response (FIR) filter based on the specified parameters.
-
- Arguments:
- - sig_len (int): Length of the signal for which the filter is being designed.
- - fs (int): Sampling frequency of the signal.
- - low_hz (float, optional): Low cutoff frequency for the filter. Required for lowpass and bandpass filters.
- - high_hz (float, optional): High cutoff frequency for the filter. Required for highpass and bandpass filters.
- - cycle (int, optional): Number of cycles to use in determining the filter order. Defaults to 3.
- - is_bandstop (bool, optional): Specifies if the filter should be a bandstop filter. Defaults to False.
-
- Returns:
- - The coefficients of the designed FIR filter.
-
- Raises:
- - FilterParameterError: If the provided parameters are invalid.
- """
-
- class FilterParameterError(Exception):
- """Custom exception for invalid filter parameters."""
-
- pass
-
- def estimate_filter_type(low_hz=None, high_hz=None, is_bandstop=False):
- """
- Estimates the filter type based on the provided low and high cutoff frequencies,
- and whether a bandstop filter is desired. Raises an exception for invalid configurations.
- """
- if low_hz is not None and low_hz < 0:
- raise FilterParameterError("low_hz must be non-negative.")
- if high_hz is not None and high_hz < 0:
- raise FilterParameterError("high_hz must be non-negative.")
- if low_hz is not None and high_hz is not None and low_hz >= high_hz:
- raise FilterParameterError(
- "low_hz must be less than high_hz for valid configurations."
- )
-
- if low_hz is not None and high_hz is not None:
- return "bandstop" if is_bandstop else "bandpass"
- elif low_hz is not None:
- return "lowpass"
- elif high_hz is not None:
- return "highpass"
- else:
- raise FilterParameterError(
- "At least one of low_hz or high_hz must be provided."
- )
-
- def determine_cutoff_frequencies(filter_mode, low_hz, high_hz):
- if filter_mode in ["lowpass", "highpass"]:
- cutoff = low_hz if filter_mode == "lowpass" else high_hz
- else: # 'bandpass' or 'bandstop'
- cutoff = [low_hz, high_hz]
- return cutoff
-
- def determine_low_freq(filter_mode, low_hz, high_hz):
- if filter_mode in ["lowpass", "bandstop"]:
- low_freq = low_hz
- else: # 'highpass' or 'bandpass'
- low_freq = high_hz if filter_mode == "highpass" else min(low_hz, high_hz)
- return low_freq
-
- def determine_order(filter_mode, fs, low_freq, sig_len, cycle):
- order = cycle * int((fs // low_freq))
- if 3 * order < sig_len:
- order = (sig_len - 1) // 3
- order = to_even(order)
- return order
-
- fs = int(fs)
- low_hz = float(low_hz) if low_hz is not None else low_hz
- high_hz = float(high_hz) if high_hz is not None else high_hz
- filter_mode = estimate_filter_type(low_hz, high_hz, is_bandstop)
- cutoff = determine_cutoff_frequencies(filter_mode, low_hz, high_hz)
- low_freq = determine_low_freq(filter_mode, low_hz, high_hz)
- order = determine_order(filter_mode, fs, low_freq, sig_len, cycle)
- numtaps = order + 1
-
- try:
- h = firwin(
- numtaps=numtaps,
- cutoff=cutoff,
- pass_zero=(filter_mode in ["highpass", "bandstop"]),
- window="hamming",
- fs=fs,
- scale=True,
- )
- except Exception as e:
- print(e)
- import ipdb
-
- ipdb.set_trace()
-
- return h
-
-
-@numpy_fn
-def plot_filter_responses(filter, fs, worN=8000, title=None):
- """
- Plots the impulse and frequency response of an FIR filter using numpy arrays.
-
- Parameters:
- - filter_coeffs (numpy.ndarray): The filter coefficients as a numpy array.
- - fs (int): The sampling frequency in Hz.
- - title (str, optional): The title of the plot. Defaults to None.
-
- Returns:
- - matplotlib.figure.Figure: The figure object containing the impulse and frequency response plots.
- """
- import scitex
-
- ww, hh = freqz(filter, worN=worN, fs=fs)
-
- fig, axes = scitex.plt.subplots(ncols=2)
- fig.suptitle(title)
-
- # Impulse Responses of FIR Filter
- ax = axes[0]
- ax.plot(filter)
- ax.set_title("Impulse Responses of FIR Filter")
- ax.set_xlabel("Tap Number")
- ax.set_ylabel("Amplitude")
-
- # Frequency Response of FIR Filter
- ax = axes[1]
- ax.plot(ww, 20 * np.log10(abs(hh) + 1e-5))
- ax.set_title("Frequency Response of FIR Filter")
- ax.set_xlabel("Frequency [Hz]")
- ax.set_ylabel("Gain [dB]")
-
- return fig
-
-
-if __name__ == "__main__":
- import scitex
-
- # Example usage
- xx, tt, fs = scitex.dsp.demo_sig()
- batch_size, n_chs, seq_len = xx.shape
-
- lp_filter = design_filter(seq_len, fs, low_hz=30, high_hz=None)
- hp_filter = design_filter(seq_len, fs, low_hz=None, high_hz=70)
- bp_filter = design_filter(seq_len, fs, low_hz=30, high_hz=70)
- bs_filter = design_filter(seq_len, fs, low_hz=30, high_hz=70, is_bandstop=True)
-
- fig = plot_filter_responses(lp_filter, fs, title="Lowpass Filter")
- fig = plot_filter_responses(hp_filter, fs, title="Highpass Filter")
- fig = plot_filter_responses(bp_filter, fs, title="Bandpass Filter")
- fig = plot_filter_responses(bs_filter, fs, title="Bandstop Filter")
-
- # Figure
- fig, axes = plt.subplots(nrows=4, ncols=2)
-
- # Time domain expressions??
- axes[0, 0].plot(lp_filter, label="Lowpass Filter")
- axes[1, 0].plot(hp_filter, label="Highpass Filter")
- axes[2, 0].plot(bp_filter, label="Bandpass Filter")
- axes[3, 0].plot(bs_filter, label="Bandstop Filter")
- # fig.suptitle("Impulse Responses of FIR Filter")
- # fig.supxlabel("Tap Number")
- # fig.supylabel("Amplitude")
- # fig.show()
-
- # Frequency response of the filters
- w, h_lp = freqz(lp_filter, worN=8000, fs=fs)
- w, h_hp = freqz(hp_filter, worN=8000, fs=fs)
- w, h_bp = freqz(bp_filter, worN=8000, fs=fs)
- w, h_bs = freqz(bs_filter, worN=8000, fs=fs)
-
- # Plotting the frequency response
- axes[0, 1].plot(w, 20 * np.log10(abs(h_lp)), label="Lowpass Filter")
- axes[1, 1].plot(w, 20 * np.log10(abs(h_hp)), label="Highpass Filter")
- axes[2, 1].plot(w, 20 * np.log10(abs(h_bp)), label="Bandpass Filter")
- axes[3, 1].plot(w, 20 * np.log10(abs(h_bs)), label="Bandstop Filter")
- # plt.title("Frequency Response of FIR Filters")
- # plt.xlabel("Frequency (Hz)")
- # plt.ylabel("Gain (dB)")
- # plt.grid(True)
- # plt.legend(loc="best")
- # plt.show()
- fig.tight_layout()
- plt.show()
-
-# @torch_fn
-# def bandpass(x, filt):
-# assert x.ndim == 3
-# xf = F.conv1d(
-# x.reshape(-1, x.shape[-1]).unsqueeze(1),
-# filt.unsqueeze(0).unsqueeze(0),
-# padding="same",
-# ).reshape(*x.shape)
-# assert x.shape == xf.shape
-# return xf
-
-# def define_bandpass_filters(seq_len, fs, freq_bands, cycle=3):
-# """
-# Defines Finite Impulse Response (FIR) filters.
-# b: The filter coefficients (or taps) of the FIR filters
-# a: The denominator coefficients of the filter's transfer function. However, FIR filters have a transfer function with a denominator equal to 1 (since they are all-zero filters with no poles).
-# """
-# # Parameters
-# n_freqs = len(freq_bands)
-# nyq = fs / 2.0
-
-# bs = []
-# for ll, hh in freq_bands:
-# wn = np.array([ll, hh]) / nyq
-# order = define_fir_order(fs, seq_len, ll, cycle=cycle)
-# bs.append(fir1(order, wn)[0])
-# return bs
-
-# def define_fir_order(fs, sizevec, flow, cycle=3):
-# """
-# Calculate filter order.
-# """
-# if cycle is None:
-# filtorder = 3 * np.fix(fs / flow)
-# else:
-# filtorder = cycle * (fs // flow)
-
-# if sizevec < 3 * filtorder:
-# filtorder = (sizevec - 1) // 3
-
-# return int(filtorder)
-
-# def n_odd_fcn(f, o, w, l):
-# """Odd case."""
-# # Variables :
-# b0 = 0
-# m = np.array(range(int(l + 1)))
-# k = m[1 : len(m)]
-# b = np.zeros(k.shape)
-
-# # Run Loop :
-# for s in range(0, len(f), 2):
-# m = (o[s + 1] - o[s]) / (f[s + 1] - f[s])
-# b1 = o[s] - m * f[s]
-# b0 = b0 + (
-# b1 * (f[s + 1] - f[s])
-# + m / 2 * (f[s + 1] * f[s + 1] - f[s] * f[s])
-# ) * abs(np.square(w[round((s + 1) / 2)]))
-# b = b + (
-# m
-# / (4 * np.pi * np.pi)
-# * (
-# np.cos(2 * np.pi * k * f[s + 1])
-# - np.cos(2 * np.pi * k * f[s])
-# )
-# / (k * k)
-# ) * abs(np.square(w[round((s + 1) / 2)]))
-# b = b + (
-# f[s + 1] * (m * f[s + 1] + b1) * np.sinc(2 * k * f[s + 1])
-# - f[s] * (m * f[s] + b1) * np.sinc(2 * k * f[s])
-# ) * abs(np.square(w[round((s + 1) / 2)]))
-
-# b = np.insert(b, 0, b0)
-# a = (np.square(w[0])) * 4 * b
-# a[0] = a[0] / 2
-# aud = np.flipud(a[1 : len(a)]) / 2
-# a2 = np.insert(aud, len(aud), a[0])
-# h = np.concatenate((a2, a[1:] / 2))
-
-# return h
-
-# def n_even_fcn(f, o, w, l):
-# """Even case."""
-# # Variables :
-# k = np.array(range(0, int(l) + 1, 1)) + 0.5
-# b = np.zeros(k.shape)
-
-# # # Run Loop :
-# for s in range(0, len(f), 2):
-# m = (o[s + 1] - o[s]) / (f[s + 1] - f[s])
-# b1 = o[s] - m * f[s]
-# b = b + (
-# m
-# / (4 * np.pi * np.pi)
-# * (
-# np.cos(2 * np.pi * k * f[s + 1])
-# - np.cos(2 * np.pi * k * f[s])
-# )
-# / (k * k)
-# ) * abs(np.square(w[round((s + 1) / 2)]))
-# b = b + (
-# f[s + 1] * (m * f[s + 1] + b1) * np.sinc(2 * k * f[s + 1])
-# - f[s] * (m * f[s] + b1) * np.sinc(2 * k * f[s])
-# ) * abs(np.square(w[round((s + 1) / 2)]))
-
-# a = (np.square(w[0])) * 4 * b
-# h = 0.5 * np.concatenate((np.flipud(a), a))
-
-# return h
-
-# def firls(n, f, o):
-# # Variables definition :
-# w = np.ones(round(len(f) / 2))
-# n += 1
-# f /= 2
-# lo = (n - 1) / 2
-
-# nodd = bool(n % 2)
-
-# if nodd: # Odd case
-# h = n_odd_fcn(f, o, w, lo)
-# else: # Even case
-# h = n_even_fcn(f, o, w, lo)
-
-# return h
-
-# def fir1(n, wn):
-# # Variables definition :
-# nbands = len(wn) + 1
-# ff = np.array((0, wn[0], wn[0], wn[1], wn[1], 1))
-
-# f0 = np.mean(ff[2:4])
-# lo = n + 1
-
-# mags = np.array(range(nbands)).reshape(1, -1) % 2
-# aa = np.ravel(np.tile(mags, (2, 1)), order="F")
-
-# # Get filter coefficients :
-# h = firls(lo - 1, ff, aa)
-
-# # Apply a window to coefficients :
-# wind = np.hamming(lo)
-# b = h * wind
-# c = np.exp(-1j * 2 * np.pi * (f0 / 2) * np.array(range(lo)))
-# b /= abs(c @ b)
-
-# return b, 1
-
-# def apply_filters(x, filts):
-# """
-# x: (batch_size, n_chs, seq_len)
-# filts: (n_filts, seq_len_filt)
-# """
-# assert x.ndims == 3
-# assert filts.ndims == 2
-# batch_size, n_chs, n_time = x.shape
-# x = x.reshape(-1, x.shape[-1]).unsqueeze(1)
-# filts = filts.unsqueeze(1)
-# n_filts = len(filts)
-# return F.conv1d(x, filts, padding="same").reshape(
-# batch_size, n_chs, n_filts, n_time
-# )
-
-# if __name__ == "__main__":
-# import torch
-# import torch.nn.functional as F
-
-# plt, CC = scitex.plt.configure_mpl(plt)
-
-# # Demo Signal
-# freqs_hz = [10, 30, 100]
-# xx, tt, fs = scitex.dsp.demo_sig(freqs_hz=freqs_hz, sig_type="periodic")
-# x = xx
-
-# seq_len = x.shape[-1]
-# freq_bands = np.array([[20, 70], [3.0, 4.0]])
-
-# # Plots the figure
-# fig, ax = scitex.plt.subplots()
-# # ax.plot(b, label="bandpass filter")
-
-# # Bandpass Filtering
-# filters = define_bandpass_filters(seq_len, fs, freq_bands, cycle=3)
-# i_filt = 0
-# # xf = bandpass(xx, filters[i_filt])
-
-# # Plots the signals
-# fig, axes = scitex.plt.subplots(nrows=2, sharex=True, sharey=True)
-# axes[0].plot(tt, xx[0, 0], label="orig")
-# axes[1].plot(tt, xf[0, 0], label="orig")
-# [ax.legend(loc="upper left") for ax in axes]
-
-# # Plots PSDs
-# psd_xx, ff_xx = scitex.dsp.psd(xx.numpy(), fs)
-# psd_xf, ff_xf = scitex.dsp.psd(xf.numpy(), fs)
-
-# fig, axes = scitex.plt.subplots(nrows=2, sharex=True, sharey=True)
-# axes[0].plot(ff_xx, psd_xx[0, 0], label="orig")
-# axes[1].plot(ff_xf, psd_xf[0, 0], label="filted")
-# [ax.legend(loc="upper left") for ax in axes]
-# plt.show()
-
-# # Multiple Filters in a parallel computation
-# x = torch.randn(33, 32, 30)
-# filters = torch.randn(20, 5)
-
-# y = apply_filters(x, filters)
-# print(y.shape) # (33, 32, 20, 30)
-
-# EOF
diff --git a/src/scitex/dsp/utils/pac.py b/src/scitex/dsp/utils/pac.py
deleted file mode 100755
index ffbdab6a4..000000000
--- a/src/scitex/dsp/utils/pac.py
+++ /dev/null
@@ -1,178 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Timestamp: "2025-05-26 06:26:29 (ywatanabe)"
-# File: /ssh:ywatanabe@sp:/home/ywatanabe/proj/scitex_repo/src/scitex/dsp/utils/pac.py
-# ----------------------------------------
-import os
-
-__FILE__ = "./src/scitex/dsp/utils/pac.py"
-__DIR__ = os.path.dirname(__FILE__)
-# ----------------------------------------
-
-#! ./env/bin/python3
-# Time-stamp: "2024-04-16 17:07:27"
-
-
-"""
-This script does XYZ.
-"""
-
-# Imports
-import sys
-
-import matplotlib.pyplot as plt
-import numpy as np
-import tensorpac
-
-import scitex
-
-
-# Functions
-def calc_pac_with_tensorpac(xx, fs, t_sec, i_batch=0, i_ch=0):
- # Morlet's Wavelet Transfrmation
- p = tensorpac.Pac(f_pha="hres", f_amp="mres", dcomplex="wavelet")
-
- # Bandpass Filtering and Hilbert Transformation
- phases = p.filter(fs, xx[i_batch, i_ch], ftype="phase", n_jobs=1) # (50, 20, 2048)
- amplitudes = p.filter(
- fs, xx[i_batch, i_ch], ftype="amplitude", n_jobs=1
- ) # (50, 20, 2048)
-
- # Calculates xpac
- k = 2
- p.idpac = (k, 0, 0)
- xpac = p.fit(phases, amplitudes) # (50, 50, 20)
- pac = xpac.mean(axis=-1) # (50, 50)
-
- freqs_amp = p.f_amp.mean(axis=-1)
- freqs_pha = p.f_pha.mean(axis=-1)
-
- pac = pac.T # (amp, pha) -> (pha, amp)
-
- return phases, amplitudes, freqs_pha, freqs_amp, pac
-
-
-def plot_PAC_scitex_vs_tensorpac(pac_scitex, pac_tp, freqs_pha, freqs_amp):
- assert pac_scitex.shape == pac_tp.shape
-
- # Plots
- fig, axes = scitex.plt.subplots(ncols=3) # , sharex=True, sharey=True
-
- # To align scalebars
- vmin = min(np.min(pac_scitex), np.min(pac_tp), np.min(pac_scitex - pac_tp))
- vmax = max(np.max(pac_scitex), np.max(pac_tp), np.max(pac_scitex - pac_tp))
-
- # scitex version
- ax = axes[0]
- ax.imshow2d(
- pac_scitex,
- cbar=False,
- vmin=vmin,
- vmax=vmax,
- )
- ax.set_title("scitex")
-
- # Tensorpac
- ax = axes[1]
- ax.imshow2d(
- pac_tp,
- cbar=False,
- vmin=vmin,
- vmax=vmax,
- )
- ax.set_title("Tensorpac")
-
- # Diff.
- ax = axes[2]
- ax.imshow2d(
- pac_scitex - pac_tp,
- cbar_label="PAC values",
- cbar_shrink=0.5,
- vmin=vmin,
- vmax=vmax,
- )
- ax.set_title(f"Difference\n(scitex - Tensorpac)")
-
- # for ax in axes:
- # ax.set_ticks(
- # x_vals=freqs_pha,
- # # y_vals=freqs_amp,
- # )
- # # ax.set_n_ticks()
-
- fig.suptitle("PAC (MI) values")
- fig.supxlabel("Frequency for phase [Hz]")
- fig.supylabel("Frequency for amplitude [Hz]")
-
- return fig
-
-
-# Snake_case alias for consistency
-def plot_pac_scitex_vs_tensorpac(pac_scitex, pac_tp, freqs_pha, freqs_amp):
- """
- Plot comparison between SciTeX and Tensorpac phase-amplitude coupling results.
-
- This is an alias for plot_PAC_scitex_vs_tensorpac with snake_case naming.
-
- Parameters
- ----------
- pac_scitex : array-like
- PAC values from SciTeX
- pac_tp : array-like
- PAC values from Tensorpac
- freqs_pha : array-like
- Phase frequencies
- freqs_amp : array-like
- Amplitude frequencies
-
- Returns
- -------
- fig : matplotlib.figure.Figure
- The generated figure
- """
- return plot_PAC_scitex_vs_tensorpac(pac_scitex, pac_tp, freqs_pha, freqs_amp)
-
-
-if __name__ == "__main__":
- import torch
-
- # Start
- CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt)
-
- # Parameters
- FS = 512
- T_SEC = 4
-
- xx, tt, fs = scitex.dsp.demo_sig(
- batch_size=2,
- n_chs=2,
- n_segments=2,
- fs=FS,
- t_sec=T_SEC,
- sig_type="tensorpac",
- )
-
- # scitex
- pac_scitex, freqs_pha, freqs_amp = scitex.dsp.pac(
- xx, fs, batch_size=2, pha_n_bands=50, amp_n_bands=30
- )
- i_batch, i_epoch = 0, 0
- pac_scitex = pac_scitex[i_batch, i_epoch]
-
- # Tensorpac
- phases, amplitudes, freqs_pha, freqs_amp, pac_tp = calc_pac_with_tensorpac(
- xx, fs, T_SEC, i_batch=0, i_ch=0
- )
-
- # Plots
- fig = plot_PAC_scitex_vs_tensorpac(pac_scitex, pac_tp, freqs_pha, freqs_amp)
- plt.show()
-
- # Close
- scitex.session.close(CONFIG)
-
-"""
-/home/ywatanabe/proj/entrance/scitex/dsp/utils/pac.py
-"""
-
-# EOF