From 63fe413f6f96d5725a8d151f5f38be3d3dabeb9d Mon Sep 17 00:00:00 2001
From: Yusuke Watanabe
Date: Mon, 27 Apr 2026 09:05:11 +0900
Subject: [PATCH] extract(nn): bridge scitex.nn via sys.modules alias to
 scitex-nn
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

scitex.nn (2,664 LOC, neural-network blocks) extracted to standalone
scitex-nn package: https://github.com/ywatanabe1989/scitex-nn

- Replace src/scitex/nn/ with 12-line sys.modules-alias __init__.py
- Collapse [nn] extra (9 transitive deps) to single scitex-nn>=0.1.0
- scitex.nn is scitex_nn: True (verified)
- Standalone tests: 458 pass / 2 fail (flaky FreqGainChanger random) / 38 skip

Decoupling in scitex-nn:
- scitex.{decorators,gen} → scitex_decorators / scitex_gen direct imports
- scitex.dsp.utils helpers vendored under _vendor_dsp_utils/ with
  prefer-real fallback
- scitex.nn.X self-references → scitex_nn.X

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 pyproject.toml                          |  19 +-
 src/scitex/nn/_AxiswiseDropout.py       |  27 --
 src/scitex/nn/_BNet.py                  | 131 ------
 src/scitex/nn/_BNet_Res.py              | 164 --------
 src/scitex/nn/_ChannelGainChanger.py    |  45 ---
 src/scitex/nn/_DropoutChannels.py       |  52 ---
 src/scitex/nn/_Filters.py               | 491 ------------------------
 src/scitex/nn/_FreqGainChanger.py       | 111 ------
 src/scitex/nn/_GaussianFilter.py        |  48 ---
 src/scitex/nn/_Hilbert.py               | 111 ------
 src/scitex/nn/_MNet_1000.py             | 158 --------
 src/scitex/nn/_ModulationIndex.py       | 224 -----------
 src/scitex/nn/_PAC.py                   | 416 --------------------
 src/scitex/nn/_PSD.py                   |  39 --
 src/scitex/nn/_ResNet1D.py              | 120 ------
 src/scitex/nn/_SpatialAttention.py      |  26 --
 src/scitex/nn/_Spectrogram.py           | 170 --------
 src/scitex/nn/_SwapChannels.py          |  52 ---
 src/scitex/nn/_TransposeLayer.py        |  19 -
 src/scitex/nn/_Wavelet.py               | 185 ---------
 src/scitex/nn/__init__.py               |  82 +---
 src/scitex/nn/_skills/SKILL.md          | 109 ------
 src/scitex/nn/_skills/architectures.md  | 175 ---------
 src/scitex/nn/_skills/augmentation.md   | 130 -------
 src/scitex/nn/_skills/filters.md        | 158 --------
 src/scitex/nn/_skills/pac.md            | 132 -------
 src/scitex/nn/_skills/spectral.md       | 142 -------
 src/scitex/nn/_skills/utility-layers.md |  96 -----
 28 files changed, 11 insertions(+), 3621 deletions(-)
 delete mode 100755 src/scitex/nn/_AxiswiseDropout.py
 delete mode 100755 src/scitex/nn/_BNet.py
 delete mode 100755 src/scitex/nn/_BNet_Res.py
 delete mode 100755 src/scitex/nn/_ChannelGainChanger.py
 delete mode 100755 src/scitex/nn/_DropoutChannels.py
 delete mode 100755 src/scitex/nn/_Filters.py
 delete mode 100755 src/scitex/nn/_FreqGainChanger.py
 delete mode 100755 src/scitex/nn/_GaussianFilter.py
 delete mode 100755 src/scitex/nn/_Hilbert.py
 delete mode 100755 src/scitex/nn/_MNet_1000.py
 delete mode 100755 src/scitex/nn/_ModulationIndex.py
 delete mode 100755 src/scitex/nn/_PAC.py
 delete mode 100755 src/scitex/nn/_PSD.py
 delete mode 100755 src/scitex/nn/_ResNet1D.py
 delete mode 100755 src/scitex/nn/_SpatialAttention.py
 delete mode 100755 src/scitex/nn/_Spectrogram.py
 delete mode 100755 src/scitex/nn/_SwapChannels.py
 delete mode 100755 src/scitex/nn/_TransposeLayer.py
 delete mode 100755 src/scitex/nn/_Wavelet.py
 delete mode 100644 src/scitex/nn/_skills/SKILL.md
 delete mode 100644 src/scitex/nn/_skills/architectures.md
 delete mode 100644 src/scitex/nn/_skills/augmentation.md
 delete mode 100644 src/scitex/nn/_skills/filters.md
 delete mode 100644 src/scitex/nn/_skills/pac.md
 delete mode 100644 src/scitex/nn/_skills/spectral.md
 delete mode 100644 src/scitex/nn/_skills/utility-layers.md
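[ Reviewer note: placed between the diffstat and the diff, where `git am`
  ignores it. ]

The bridge module itself is not shown in this patch (only the
src/scitex/nn/__init__.py diffstat line above), so here is a minimal
sketch of what a 12-line sys.modules-alias __init__.py plausibly looks
like. Illustrative only: the import name scitex_nn comes from the
"scitex.nn is scitex_nn" bullet above, but the body and error-message
wording are assumptions, not the shipped code.

    # src/scitex/nn/__init__.py (sketch, not the verbatim 12 lines)
    import sys

    try:
        import scitex_nn
    except ImportError as err:  # [nn] extra not installed
        raise ImportError(
            "scitex.nn has moved to the standalone 'scitex-nn' package; "
            "install it with: pip install scitex[nn]"
        ) from err

    # Swap this module object for the real package, so that
    # `import scitex.nn` hands back scitex_nn itself.
    sys.modules[__name__] = scitex_nn

Because the whole module object is replaced in sys.modules, every
existing `scitex.nn.X` attribute access resolves against scitex_nn,
which is what the `scitex.nn is scitex_nn: True` check above verifies.

The prefer-real fallback for the vendored scitex.dsp.utils helpers
presumably follows the usual try/except-import shape. The vendored
module name `_filters` below is hypothetical; the two helper names are
the ones _Filters.py imported from scitex.dsp.utils:

    # scitex_nn/_vendor_dsp_utils/__init__.py (sketch)
    try:
        # Prefer the real helpers when the full scitex distribution
        # is installed alongside scitex-nn.
        from scitex.dsp.utils import (
            build_bandpass_filters,
            init_bandpass_filters,
        )
    except ImportError:
        # Fall back to the vendored copies (hypothetical module name).
        from ._filters import (
            build_bandpass_filters,
            init_bandpass_filters,
        )

diff --git a/pyproject.toml b/pyproject.toml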
index bdfecd610..3e6cdcf1e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -457,24 +457,7 @@ msword = [
 
 # NN Module - Neural Networks
 # Use: pip install scitex[nn]
-nn = [
-    "seaborn",
-    "ipdb",
-    "matplotlib",
-    "joblib",
-    "tensorpac",
-    "ruamel.yaml",
-    "h5py",
-    "readchar",
-    "xarray",
-    # # Heavy dependencies handled by _AVAILABLE flags
-    # "torch",
-    # "torchaudio",
-    # "torchsummary",
-    # "julius",
-    # "mne",
-    # "ripple_detection",
-]
+nn = ["scitex-nn>=0.1.0"]
 
 # Notebook Module - Jupyter notebook utilities
 # Use: pip install scitex[notebook]
diff --git a/src/scitex/nn/_AxiswiseDropout.py b/src/scitex/nn/_AxiswiseDropout.py
deleted file mode 100755
index cd75112e7..000000000
--- a/src/scitex/nn/_AxiswiseDropout.py
+++ /dev/null
@@ -1,27 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-03-30 07:27:27 (ywatanabe)"
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-
-class AxiswiseDropout(nn.Module):
-    def __init__(self, dropout_prob=0.5, dim=1):
-        super(AxiswiseDropout, self).__init__()
-        self.dropout_prob = dropout_prob
-        self.dim = dim
-
-    def forward(self, x):
-        if self.training:
-            sizes = [s if i == self.dim else 1 for i, s in enumerate(x.size())]
-            dropout_mask = F.dropout(
-                torch.ones(*sizes, device=x.device, dtype=x.dtype),
-                self.dropout_prob,
-                True,
-            )
-
-            # Expand the mask to the size of the input tensor and apply it
-            return x * dropout_mask.expand_as(x)
-        return x
diff --git a/src/scitex/nn/_BNet.py b/src/scitex/nn/_BNet.py
deleted file mode 100755
index 3f45f9956..000000000
--- a/src/scitex/nn/_BNet.py
+++ /dev/null
@@ -1,131 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2023-05-15 16:44:27 (ywatanabe)"
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torchsummary import summary
-
-from ._ChannelGainChanger import ChannelGainChanger
-from ._DropoutChannels import DropoutChannels
-from ._FreqGainChanger import FreqGainChanger
-from ._MNet_1000 import MNet_1000
-
-# Import specific nn modules to avoid circular imports
-from ._SpatialAttention import SpatialAttention
-from ._SwapChannels import SwapChannels
-
-
-class BHead(nn.Module):
-    def __init__(self, n_chs_in, n_chs_out):
-        super().__init__()
-        self.sa = SpatialAttention(n_chs_in)
-        self.conv11 = nn.Conv1d(
-            in_channels=n_chs_in, out_channels=n_chs_out, kernel_size=1
-        )
-
-    def forward(self, x):
-        x = self.sa(x)
-        x = self.conv11(x)
-        return x
-
-
-class BNet(nn.Module):
-    def __init__(self, BNet_config, MNet_config):
-        super().__init__()
-        self.dummy_param = nn.Parameter(torch.empty(0))
-        N_VIRTUAL_CHS = 32
-
-        self.sc = SwapChannels()
-        self.dc = DropoutChannels(dropout=0.01)
-        self.fgc = FreqGainChanger(BNet_config["n_bands"], BNet_config["SAMP_RATE"])
-        self.heads = nn.ModuleList(
-            [
-                BHead(n_ch, N_VIRTUAL_CHS).to(self.dummy_param.device)
-                for n_ch in BNet_config["n_chs"]
-            ]
-        )
-
-        self.cgcs = [ChannelGainChanger(n_ch) for n_ch in BNet_config["n_chs"]]
-        # self.cgc = ChannelGainChanger(N_VIRTUAL_CHS)
-
-        MNet_config["n_chs"] = N_VIRTUAL_CHS  # BNet_config["n_chs"] # override
-        self.MNet = MNet_1000(MNet_config)
-
-        self.fcs = nn.ModuleList(
-            [
-                nn.Sequential(
-                    # nn.Linear(N_FC_IN, config["n_fc1"]),
-                    nn.Mish(),
-                    nn.Dropout(BNet_config["d_ratio1"]),
-                    nn.Linear(BNet_config["n_fc1"], BNet_config["n_fc2"]),
-                    nn.Mish(),
-                    nn.Dropout(BNet_config["d_ratio2"]),
-                    nn.Linear(BNet_config["n_fc2"], BNet_config["n_classes"][i_head]),
-                )
-                for
i_head, _ in enumerate(range(len(BNet_config["n_chs"]))) - ] - ) - - @staticmethod - def _znorm_along_the_last_dim(x): - return (x - x.mean(dim=-1, keepdims=True)) / x.std(dim=-1, keepdims=True) - - def forward(self, x, i_head): - x = self._znorm_along_the_last_dim(x) - # x = self.sc(x) - x = self.dc(x) - x = self.fgc(x) - x = self.cgcs[i_head](x) - x = self.heads[i_head](x) - import ipdb - - ipdb.set_trace() - # x = self.cgc(x) - x = self.MNet.forward_bb(x) - x = self.fcs[i_head](x) - return x - - -# BNet_config = { -# "n_chs": 32, -# "n_bands": 6, -# "SAMP_RATE": 1000, -# } -BNet_config = { - "n_bands": 6, - "SAMP_RATE": 250, - # "n_chs": 270, - "n_fc1": 1024, - "d_ratio1": 0.85, - "n_fc2": 256, - "d_ratio2": 0.85, -} - - -if __name__ == "__main__": - ## Demo data - # MEG - BS, N_CHS, SEQ_LEN = 16, 160, 1000 - x_MEG = torch.rand(BS, N_CHS, SEQ_LEN).cuda() - # EEG - BS, N_CHS, SEQ_LEN = 16, 19, 1000 - x_EEG = torch.rand(BS, N_CHS, SEQ_LEN).cuda() - - # model = MNetBackBorn(scitex.nn.MNet_config).cuda() - # model(x_MEG) - # Model - BNet_config["n_chs"] = [160, 19] - BNet_config["n_classes"] = [2, 4] - model = BNet(BNet_config, scitex.nn.MNet_config).cuda() - - # MEG - y = model(x_MEG, 0) - y = model(x_EEG, 1) - - # # EEG - # y = model(x_EEG) - - y.sum().backward() diff --git a/src/scitex/nn/_BNet_Res.py b/src/scitex/nn/_BNet_Res.py deleted file mode 100755 index cbd9ec9a2..000000000 --- a/src/scitex/nn/_BNet_Res.py +++ /dev/null @@ -1,164 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2023-05-15 17:09:58 (ywatanabe)" - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchsummary import summary - -import scitex - - -class BHead(nn.Module): - def __init__(self, n_chs_in, n_chs_out): - super().__init__() - self.sa = scitex.nn.SpatialAttention(n_chs_in) - self.conv11 = nn.Conv1d( - in_channels=n_chs_in, out_channels=n_chs_out, kernel_size=1 - ) - - def forward(self, x): - x = self.sa(x) - x = self.conv11(x) - return x - - -class BNet(nn.Module): - def __init__(self, BNet_config, MNet_config): - super().__init__() - self.dummy_param = nn.Parameter(torch.empty(0)) - # N_VIRTUAL_CHS = 32 - # "n_virtual_chs":16, - - self.sc = scitex.nn.SwapChannels() - self.dc = scitex.nn.DropoutChannels(dropout=0.01) - self.fgc = scitex.nn.FreqGainChanger( - BNet_config["n_bands"], BNet_config["SAMP_RATE"] - ) - self.heads = nn.ModuleList( - [ - BHead(n_ch, BNet_config["n_virtual_chs"]).to(self.dummy_param.device) - for n_ch in BNet_config["n_chs_of_modalities"] - ] - ) - - self.cgcs = [ - scitex.nn.ChannelGainChanger(n_ch) - for n_ch in BNet_config["n_chs_of_modalities"] - ] - # self.cgc = scitex.nn.ChannelGainChanger(N_VIRTUAL_CHS) - - # MNet_config["n_chs"] = BNet_config["n_virtual_chs"] # BNet_config["n_chs"] # override - - n_chs = BNet_config["n_virtual_chs"] - self.blk1 = scitex.nn.ResNetBasicBlock(n_chs, n_chs) - self.blk2 = scitex.nn.ResNetBasicBlock(int(n_chs / 2**1), int(n_chs / 2**1)) - self.blk3 = scitex.nn.ResNetBasicBlock(int(n_chs / 2**2), int(n_chs / 2**2)) - self.blk4 = scitex.nn.ResNetBasicBlock(int(n_chs / 2**3), int(n_chs / 2**3)) - self.blk5 = scitex.nn.ResNetBasicBlock(1, 1) - self.blk6 = scitex.nn.ResNetBasicBlock(1, 1) - self.blk7 = scitex.nn.ResNetBasicBlock(1, 1) - - # self.MNet = scitex.nn.MNet_1000(MNet_config) - - # self.fcs = nn.ModuleList( - # [ - # nn.Sequential( - # # nn.Linear(N_FC_IN, config["n_fc1"]), - # nn.Mish(), - # nn.Dropout(BNet_config["d_ratio1"]), - # nn.Linear(BNet_config["n_fc1"], 
BNet_config["n_fc2"]), - # nn.Mish(), - # nn.Dropout(BNet_config["d_ratio2"]), - # nn.Linear(BNet_config["n_fc2"], BNet_config["n_classes_of_modalities"][i_head]), - # ) - # for i_head, _ in enumerate(range(len(BNet_config["n_chs_of_modalities"]))) - # ] - # ) - - @staticmethod - def _znorm_along_the_last_dim(x): - return (x - x.mean(dim=-1, keepdims=True)) / x.std(dim=-1, keepdims=True) - - def forward(self, x, i_head): - x = self._znorm_along_the_last_dim(x) - # x = self.sc(x) - x = self.dc(x) - x = self.fgc(x) - x = self.cgcs[i_head](x) - x = self.heads[i_head](x) - - x = self.blk1(x) - x = F.avg_pool1d(x.transpose(1, 2), kernel_size=2).transpose(1, 2) - x = F.avg_pool1d(x, kernel_size=2) - x = self.blk2(x) - x = F.avg_pool1d(x.transpose(1, 2), kernel_size=2).transpose(1, 2) - x = F.avg_pool1d(x, kernel_size=2) - x = self.blk3(x) - x = F.avg_pool1d(x.transpose(1, 2), kernel_size=2).transpose(1, 2) - x = F.avg_pool1d(x, kernel_size=2) - x = self.blk4(x) - x = F.avg_pool1d(x.transpose(1, 2), kernel_size=2).transpose(1, 2) - x = F.avg_pool1d(x, kernel_size=2) - - x = self.blk5(x) - x = F.avg_pool1d(x, kernel_size=2) - x = self.blk6(x) - x = F.avg_pool1d(x, kernel_size=2) - x = self.blk7(x) - x = F.avg_pool1d(x, kernel_size=2) - - import ipdb - - ipdb.set_trace() - - # x = self.cgc(x) - x = self.MNet.forward_bb(x) - x = self.fcs[i_head](x) - return x - - -# BNet_config = { -# "n_chs": 32, -# "n_bands": 6, -# "SAMP_RATE": 1000, -# } -BNet_config = { - "n_bands": 6, - "n_virtual_chs": 16, - "SAMP_RATE": 250, - "n_fc1": 1024, - "d_ratio1": 0.85, - "n_fc2": 256, - "d_ratio2": 0.85, -} - - -if __name__ == "__main__": - ## Demo data - # MEG - BS, N_CHS, SEQ_LEN = 16, 160, 1024 - x_MEG = torch.rand(BS, N_CHS, SEQ_LEN).cuda() - # EEG - BS, N_CHS, SEQ_LEN = 16, 19, 1024 - x_EEG = torch.rand(BS, N_CHS, SEQ_LEN).cuda() - - # m = scitex.nn.ResNetBasicBlock(19, 19).cuda() - # m(x_EEG) - # model = MNetBackBorn(scitex.nn.MNet_config).cuda() - # model(x_MEG) - # Model - BNet_config["n_chs_of_modalities"] = [160, 19] - BNet_config["n_classes_of_modalities"] = [2, 4] - model = BNet(BNet_config, scitex.nn.MNet_config).cuda() - - # MEG - y = model(x_MEG, 0) - y = model(x_EEG, 1) - - # # EEG - # y = model(x_EEG) - - y.sum().backward() diff --git a/src/scitex/nn/_ChannelGainChanger.py b/src/scitex/nn/_ChannelGainChanger.py deleted file mode 100755 index 550e09255..000000000 --- a/src/scitex/nn/_ChannelGainChanger.py +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2023-04-23 11:02:45 (ywatanabe)" - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchsummary import summary - -import scitex - - -class ChannelGainChanger(nn.Module): - def __init__( - self, - n_chs, - ): - super().__init__() - self.n_chs = n_chs - - def forward(self, x): - """x: [batch_size, n_chs, seq_len]""" - if self.training: - ch_gains = ( - torch.rand(self.n_chs).unsqueeze(0).unsqueeze(-1).to(x.device) + 0.5 - ) - ch_gains = F.softmax(ch_gains, dim=1) - x *= ch_gains - - return x - - -if __name__ == "__main__": - ## Demo data - bs, n_chs, seq_len = 16, 360, 1000 - x = torch.rand(bs, n_chs, seq_len) - - cgc = ChGainChanger(n_chs) - print(cgc(x).shape) # [16, 19, 1000] - - # sb = SubjectBlock(n_chs=n_chs) - # print(sb(x, s).shape) # [16, 270, 1000] - - # summary(sb, x, s) diff --git a/src/scitex/nn/_DropoutChannels.py b/src/scitex/nn/_DropoutChannels.py deleted file mode 100755 index 33968605e..000000000 --- a/src/scitex/nn/_DropoutChannels.py +++ 
/dev/null @@ -1,52 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2023-05-04 21:50:22 (ywatanabe)" - -import random - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchsummary import summary - -import scitex - - -class DropoutChannels(nn.Module): - def __init__(self, dropout=0.5): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - - def forward(self, x): - """x: [batch_size, n_chs, seq_len]""" - if self.training: - orig_chs = torch.arange(x.shape[1]) - - indi_orig = self.dropout(torch.ones(x.shape[1])).bool() - chs_to_shuffle = orig_chs[~indi_orig] - - x[:, chs_to_shuffle] = torch.randn(x[:, chs_to_shuffle].shape).to(x.device) - - # rand_chs = random.sample(list(np.array(chs_to_shuffle)), len(chs_to_shuffle)) - - # swapped_chs = orig_chs.clone() - # swapped_chs[~indi_orig] = torch.LongTensor(rand_chs) - - # x = x[:, swapped_chs.long(), :] - - return x - - -if __name__ == "__main__": - ## Demo data - bs, n_chs, seq_len = 16, 360, 1000 - x = torch.rand(bs, n_chs, seq_len) - - dc = DropoutChannels(dropout=0.1) - print(dc(x).shape) # [16, 19, 1000] - - # sb = SubjectBlock(n_chs=n_chs) - # print(sb(x, s).shape) # [16, 270, 1000] - - # summary(sb, x, s) diff --git a/src/scitex/nn/_Filters.py b/src/scitex/nn/_Filters.py deleted file mode 100755 index 8659e79e6..000000000 --- a/src/scitex/nn/_Filters.py +++ /dev/null @@ -1,491 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-05-28 17:05:26 (ywatanabe)" -# File: /ssh:sp:/home/ywatanabe/proj/scitex_repo/src/scitex/nn/_Filters.py -# ---------------------------------------- -import os - -__FILE__ = "./src/scitex/nn/_Filters.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- - -# Time-stamp: "2024-11-26 22:23:40 (ywatanabe)" - -import numpy as np - -THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/nn/_Filters.py" - -""" -Implements various neural network filter layers: - - BaseFilter1D: Abstract base class for 1D filters - - BandPassFilter: Implements bandpass filtering - - BandStopFilter: Implements bandstop filtering - - LowPassFilter: Implements lowpass filtering - - HighPassFilter: Implements highpass filtering - - GaussianFilter: Implements Gaussian smoothing - - DifferentiableBandPassFilter: Implements learnable bandpass filtering -""" - -# Imports -import sys -from abc import abstractmethod - -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -from scitex.dsp.utils import build_bandpass_filters, init_bandpass_filters -from scitex.dsp.utils._ensure_3d import ensure_3d -from scitex.dsp.utils._ensure_even_len import ensure_even_len -from scitex.dsp.utils._zero_pad import zero_pad -from scitex.dsp.utils.filter import design_filter -from scitex.gen._to_even import to_even - - -class BaseFilter1D(nn.Module): - def __init__(self, fp16=False, in_place=False): - super().__init__() - self.fp16 = fp16 - self.in_place = in_place - # self.kernels = None - - @abstractmethod - def init_kernels( - self, - ): - """ - Abstract method to initialize filter kernels. - Must be implemented by subclasses. 
- """ - pass - - def forward(self, x, t=None, edge_len=0): - """Apply the filter to input signal x with shape: (batch_size, n_chs, seq_len)""" - - # Shape check - if self.fp16: - x = x.half() - - x = ensure_3d(x) - batch_size, n_chs, seq_len = x.shape - - # Kernel Check - if self.kernels is None: - raise ValueError("Filter kernels has not been initialized.") - - # Filtering - x = self.flip_extend(x, self.kernel_size // 2) - x = self.batch_conv(x, self.kernels, padding=0) - x = x[..., :seq_len] - - assert x.shape == ( - batch_size, - n_chs, - len(self.kernels), - seq_len, - ), f"The shape of the filtered signal ({x.shape}) does not match the expected shape: ({batch_size}, {n_chs}, {len(self.kernels)}, {seq_len})." - - # Edge remove - x = self.remove_edges(x, edge_len) - - if t is None: - return x - else: - t = self.remove_edges(t, edge_len) - return x, t - - @property - def kernel_size( - self, - ): - ks = self.kernels.shape[-1] - # if not ks % 2 == 0: - # raise ValueError("Kernel size should be an even number.") - return ks - - @staticmethod - def flip_extend(x, extension_length): - first_segment = x[:, :, :extension_length].flip(dims=[-1]) - last_segment = x[:, :, -extension_length:].flip(dims=[-1]) - return torch.cat([first_segment, x, last_segment], dim=-1) - - @staticmethod - def batch_conv(x, kernels, padding="same"): - """ - x: (batch_size, n_chs, seq_len) - kernels: (n_kernels, seq_len_filt) - """ - assert x.ndim == 3 - assert kernels.ndim == 2 - batch_size, n_chs, n_time = x.shape - x = x.reshape(-1, x.shape[-1]).unsqueeze(1) - kernels = kernels.unsqueeze(1) # add the channel dimension - n_kernels = len(kernels) - filted = F.conv1d(x, kernels.type_as(x), padding=padding) - return filted.reshape(batch_size, n_chs, n_kernels, -1) - - @staticmethod - def remove_edges(x, edge_len): - edge_len = x.shape[-1] // 8 if edge_len == "auto" else edge_len - - if 0 < edge_len: - return x[..., edge_len:-edge_len] - else: - return x - - -class BandPassFilter(BaseFilter1D): - def __init__(self, bands, fs, seq_len, fp16=False): - super().__init__(fp16=fp16) - - self.fp16 = fp16 - - # Ensures bands shape - assert bands.ndim == 2 - - # Check bands definitions - nyq = fs / 2.0 - # Convert bands to tensor if it's a numpy array - if isinstance(bands, np.ndarray): - bands = torch.tensor(bands) - bands = torch.clip(bands, 0.1, nyq - 1) - for ll, hh in bands: - assert 0 < ll - assert ll < hh - assert hh < nyq - - # Prepare kernels - kernels = self.init_kernels(seq_len, fs, bands) - if fp16: - kernels = kernels.half() - self.register_buffer( - "kernels", - kernels, - ) - - @staticmethod - def init_kernels(seq_len, fs, bands): - # Convert seq_len and fs to numpy arrays for design_filter (expects numpy_fn) - seq_len_array = np.array([seq_len]) - fs_array = np.array([fs]) - filters = [ - design_filter( - seq_len_array, - fs_array, - low_hz=ll, - high_hz=hh, - is_bandstop=False, - ) - for ll, hh in bands - ] - - # Convert filters list to tensors for zero_pad - filters_tensors = [ - torch.tensor(f) if not isinstance(f, torch.Tensor) else f for f in filters - ] - - kernels = zero_pad(filters_tensors) - kernels = ensure_even_len(kernels) - if not isinstance(kernels, torch.Tensor): - kernels = torch.tensor(kernels) - kernels = kernels.clone().detach() - # kernels = kernels.clone().detach().requires_grad_(True) - return kernels - - -# /home/ywatanabe/proj/scitex/src/scitex/nn/_Filters.py:155: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or 
sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor). -# kernels = torch.tensor(kernels).clone().detach() - - -class BandStopFilter(BaseFilter1D): - def __init__(self, bands, fs, seq_len): - super().__init__() - - # Ensures bands shape - assert bands.ndim == 2 - - # Check bands definitions - nyq = fs / 2.0 - bands = np.clip(bands, 0.1, nyq - 1) - for ll, hh in bands: - assert 0 < ll - assert ll < hh - assert hh < nyq - - self.register_buffer("kernels", self.init_kernels(seq_len, fs, bands)) - - @staticmethod - def init_kernels(seq_len, fs, bands): - # Convert to numpy arrays for design_filter - seq_len_array = np.array([seq_len]) - fs_array = np.array([fs]) - filters = [ - design_filter( - seq_len_array, fs_array, low_hz=ll, high_hz=hh, is_bandstop=True - ) - for ll, hh in bands - ] - # Convert filters list to tensors for zero_pad - filters_tensors = [ - torch.tensor(f) if not isinstance(f, torch.Tensor) else f for f in filters - ] - kernels = zero_pad(filters_tensors) - kernels = ensure_even_len(kernels) - if not isinstance(kernels, torch.Tensor): - kernels = torch.tensor(kernels) - return kernels - - -class LowPassFilter(BaseFilter1D): - def __init__(self, cutoffs_hz, fs, seq_len): - super().__init__() - - # Ensures bands shape - assert cutoffs_hz.ndim == 1 - - # Check bands definitions - nyq = fs / 2.0 - bands = np.clip(cutoffs_hz, 0.1, nyq - 1) - for cc in cutoffs_hz: - assert 0 < cc - assert cc < nyq - - self.register_buffer("kernels", self.init_kernels(seq_len, fs, cutoffs_hz)) - - @staticmethod - def init_kernels(seq_len, fs, cutoffs_hz): - # Convert to numpy arrays for design_filter - seq_len_array = np.array([seq_len]) - fs_array = np.array([fs]) - filters = [ - design_filter( - seq_len_array, fs_array, low_hz=None, high_hz=cc, is_bandstop=False - ) - for cc in cutoffs_hz - ] - # Convert filters list to tensors for zero_pad - filters_tensors = [ - torch.tensor(f) if not isinstance(f, torch.Tensor) else f for f in filters - ] - kernels = zero_pad(filters_tensors) - kernels = ensure_even_len(kernels) - if not isinstance(kernels, torch.Tensor): - kernels = torch.tensor(kernels) - return kernels - - -class HighPassFilter(BaseFilter1D): - def __init__(self, cutoffs_hz, fs, seq_len): - super().__init__() - - # Ensures bands shape - assert cutoffs_hz.ndim == 1 - - # Check bands definitions - nyq = fs / 2.0 - bands = np.clip(cutoffs_hz, 0.1, nyq - 1) - for cc in cutoffs_hz: - assert 0 < cc - assert cc < nyq - - self.register_buffer("kernels", self.init_kernels(seq_len, fs, cutoffs_hz)) - - @staticmethod - def init_kernels(seq_len, fs, cutoffs_hz): - # Convert to numpy arrays for design_filter - seq_len_array = np.array([seq_len]) - fs_array = np.array([fs]) - filters = [ - design_filter( - seq_len_array, fs_array, low_hz=cc, high_hz=None, is_bandstop=False - ) - for cc in cutoffs_hz - ] - # Convert filters list to tensors for zero_pad - filters_tensors = [ - torch.tensor(f) if not isinstance(f, torch.Tensor) else f for f in filters - ] - kernels = zero_pad(filters_tensors) - kernels = ensure_even_len(kernels) - if not isinstance(kernels, torch.Tensor): - kernels = torch.tensor(kernels) - return kernels - - -class GaussianFilter(BaseFilter1D): - def __init__(self, sigma): - super().__init__() - self.sigma = to_even(sigma) - self.register_buffer("kernels", self.init_kernels(sigma)) - - @staticmethod - def init_kernels(sigma): - kernel_size = sigma * 6 # +/- 3SD - kernel_range = torch.arange(0, kernel_size) - kernel_size // 2 - kernel = torch.exp(-0.5 * 
(kernel_range / sigma) ** 2) - kernel /= kernel.sum() - kernels = kernel.unsqueeze(0) # n_filters = 1 - kernels = ensure_even_len(kernels) - return torch.tensor(kernels) - - -class DifferentiableBandPassFilter(BaseFilter1D): - def __init__( - self, - sig_len, - fs, - pha_low_hz=2, - pha_high_hz=20, - pha_n_bands=30, - amp_low_hz=80, - amp_high_hz=160, - amp_n_bands=50, - cycle=3, - fp16=False, - ): - super().__init__(fp16=fp16) - - # Attributes - self.pha_low_hz = pha_low_hz - self.pha_high_hz = pha_high_hz - self.amp_low_hz = amp_low_hz - self.amp_high_hz = amp_high_hz - self.sig_len = sig_len - self.fs = fs - self.cycle = cycle - self.fp16 = fp16 - - # Check bands definitions - nyq = fs / 2.0 - pha_high_hz = torch.tensor(pha_high_hz).clip(0.1, nyq - 1) - pha_low_hz = torch.tensor(pha_low_hz).clip(0.1, pha_high_hz - 1) - amp_high_hz = torch.tensor(amp_high_hz).clip(0.1, nyq - 1) - amp_low_hz = torch.tensor(amp_low_hz).clip(0.1, amp_high_hz - 1) - - assert pha_low_hz < pha_high_hz < nyq - assert amp_low_hz < amp_high_hz < nyq - - # Prepare kernels - self.init_kernels = init_bandpass_filters - self.build_bandpass_filters = build_bandpass_filters - kernels, self.pha_mids, self.amp_mids = self.init_kernels( - sig_len=sig_len, - fs=fs, - pha_low_hz=pha_low_hz, - pha_high_hz=pha_high_hz, - pha_n_bands=pha_n_bands, - amp_low_hz=amp_low_hz, - amp_high_hz=amp_high_hz, - amp_n_bands=amp_n_bands, - cycle=cycle, - ) - - self.register_buffer( - "kernels", - kernels, - ) - # self.register_buffer("pha_mids", pha_mids) - # self.register_buffer("amp_mids", amp_mids) - # self.pha_mids = nn.Parameter(pha_mids.detach()) - # self.amp_mids = nn.Parameter(amp_mids.detach()) - - if fp16: - self.kernels = self.kernels.half() - # self.pha_mids = self.pha_mids.half() - # self.amp_mids = self.amp_mids.half() - - def forward(self, x, t=None, edge_len=0): - # Constrains the parameter spaces - torch.clip(self.pha_mids, self.pha_low_hz, self.pha_high_hz) - torch.clip(self.amp_mids, self.amp_low_hz, self.amp_high_hz) - - self.kernels = self.build_bandpass_filters( - self.sig_len, self.fs, self.pha_mids, self.amp_mids, self.cycle - ) - return super().forward(x=x, t=t, edge_len=edge_len) - - -if __name__ == "__main__": - import scitex - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start( - sys, plt, fig_scale=5 - ) - - xx, tt, fs = scitex.dsp.demo_sig(sig_type="chirp", fs=1024) - xx = torch.tensor(xx).cuda() - # bands = np.array([[2, 3], [3, 4]]) - # BandPassFilter(bands, fs, xx.shape) - m = DifferentiableBandPassFilter(xx.shape[-1], fs).cuda() - - scitex.ai.utils.check_params(m) - # {'pha_mids': (torch.Size([30]), 'Learnable'), - # 'amp_mids': (torch.Size([50]), 'Learnable')} - - xf = m(xx) # (8, 19, 80, 2048) - - xf.sum().backward() # OK, differentiable - - m.pha_mids - # Parameter containing: - # tensor([ 2.0000, 2.6207, 3.2414, 3.8621, 4.4828, 5.1034, 5.7241, 6.3448, - # 6.9655, 7.5862, 8.2069, 8.8276, 9.4483, 10.0690, 10.6897, 11.3103, - # 11.9310, 12.5517, 13.1724, 13.7931, 14.4138, 15.0345, 15.6552, 16.2759, - # 16.8966, 17.5172, 18.1379, 18.7586, 19.3793, 20.0000], - # requires_grad=True) - m.amp_mids - # Parameter containing: - # tensor([ 80.0000, 81.6327, 83.2653, 84.8980, 86.5306, 88.1633, 89.7959, - # 91.4286, 93.0612, 94.6939, 96.3265, 97.9592, 99.5918, 101.2245, - # 102.8571, 104.4898, 106.1225, 107.7551, 109.3878, 111.0204, 112.6531, - # 114.2857, 115.9184, 117.5510, 119.1837, 120.8163, 122.4490, 124.0816, - # 125.7143, 127.3469, 128.9796, 130.6122, 132.2449, 133.8775, 135.5102, - # 
137.1429, 138.7755, 140.4082, 142.0408, 143.6735, 145.3061, 146.9388, - # 148.5714, 150.2041, 151.8367, 153.4694, 155.1020, 156.7347, 158.3673, - # 160.0000], requires_grad=True) - - # PSD - bands = torch.hstack([m.pha_mids, m.amp_mids]) - - # Plots PSD - # matplotlib.use("TkAgg") - fig, axes = scitex.plt.subplots(nrows=1 + len(bands), ncols=2) - - psd, ff = scitex.dsp.psd(xx, fs) # Orig - axes[0, 0].plot(tt, xx[0, 0].detach().cpu().numpy(), label="orig") - axes[0, 1].plot( - ff.detach().cpu().numpy(), - psd[0, 0].detach().cpu().numpy(), - label="orig", - ) - - for i_filt in range(len(bands)): - mid_hz = int(bands[i_filt].item()) - psd_f, ff_f = scitex.dsp.psd(xf[:, :, i_filt, :], fs) - axes[i_filt + 1, 0].plot( - tt, - xf[0, 0, i_filt].detach().cpu().numpy(), - label=f"filted at {mid_hz} Hz", - ) - axes[i_filt + 1, 1].plot( - ff_f.detach().cpu().numpy(), - psd_f[0, 0].detach().cpu().numpy(), - label=f"filted at {mid_hz} Hz", - ) - for ax in axes.ravel(): - ax.legend(loc="upper left") - - scitex.io.save(fig, "traces.png") - # plt.show() - - # Close - scitex.session.close(CONFIG) - -""" -/home/ywatanabe/proj/entrance/scitex/dsp/nn/_Filters.py -""" - -# EOF diff --git a/src/scitex/nn/_FreqGainChanger.py b/src/scitex/nn/_FreqGainChanger.py deleted file mode 100755 index db8ce20cd..000000000 --- a/src/scitex/nn/_FreqGainChanger.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2023-04-23 11:02:34 (ywatanabe)" - -import julius -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchsummary import summary - -import scitex - -# BANDS_LIM_HZ_DICT = { -# "delta": [0.5, 4], -# "theta": [4, 8], -# "lalpha": [8, 10], -# "halpha": [10, 13], -# "beta": [13, 32], -# "gamma": [32, 75], -# } - - -# class FreqDropout(nn.Module): -# def __init__(self, n_bands, samp_rate, dropout_ratio=0.5): -# super().__init__() -# self.dropout = nn.Dropout(p=0.5) -# self.n_bands = n_bands -# self.samp_rate = samp_rate -# # self. 
-# self.register_buffer("ones", torch.ones(self.n_bands)) - -# def forward(self, x): -# """x: [batch_size, n_chs, seq_len]""" -# x = julius.bands.split_bands(x, self.samp_rate, n_bands=self.n_bands) - -# gains_orig = x.reshape(len(x), -1).abs().sum(axis=-1) -# sum_gains_orig = gains_orig.sum() - -# # use_freqs = self.dropout(torch.ones(self.n_bands)).bool().long() -# use_freqs = self.dropout(self.ones) / 2 # .bool().long() - -# gains = gains_orig * use_freqs -# sum_gains = gains.sum() -# gain_ratio = sum_gains / sum_gains_orig - - -# x *= use_freqs.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) -# x /= gain_ratio -# x = x.sum(axis=0) - -# return x - - -class FreqGainChanger(nn.Module): - def __init__(self, n_bands, samp_rate, dropout_ratio=0.5): - super().__init__() - self.dropout = nn.Dropout(p=0.5) - self.n_bands = n_bands - self.samp_rate = samp_rate - # self.register_buffer("ones", torch.ones(self.n_bands)) - - def forward(self, x): - """x: [batch_size, n_chs, seq_len]""" - if self.training: - x = julius.bands.split_bands(x, self.samp_rate, n_bands=self.n_bands) - freq_gains = ( - torch.rand(self.n_bands) - .unsqueeze(-1) - .unsqueeze(-1) - .unsqueeze(-1) - .to(x.device) - + 0.5 - ) - freq_gains = F.softmax(freq_gains, dim=0) - x = (x * freq_gains).sum(axis=0) - - return x - # import ipdb; ipdb.set_trace() - - # gains_orig = x.reshape(len(x), -1).abs().sum(axis=-1) - # sum_gains_orig = gains_orig.sum() - - # # use_freqs = self.dropout(torch.ones(self.n_bands)).bool().long() - # use_freqs = self.dropout(self.ones) / 2 # .bool().long() - - # gains = gains_orig * use_freqs - # sum_gains = gains.sum() - # gain_ratio = sum_gains / sum_gains_orig - - # x *= use_freqs.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) - # x /= gain_ratio - # x = x.sum(axis=0) - - # return x - - -if __name__ == "__main__": - # Parameters - N_BANDS = 10 - SAMP_RATE = 1000 - BS, N_CHS, SEQ_LEN = 16, 360, 1000 - - # Demo data - x = torch.rand(BS, N_CHS, SEQ_LEN).cuda() - - # Feedforward - fgc = FreqGainChanger(N_BANDS, SAMP_RATE).cuda() - # fd.eval() - y = fgc(x) - y.sum().backward() diff --git a/src/scitex/nn/_GaussianFilter.py b/src/scitex/nn/_GaussianFilter.py deleted file mode 100755 index 4ce585ae1..000000000 --- a/src/scitex/nn/_GaussianFilter.py +++ /dev/null @@ -1,48 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-04-01 18:14:44 (ywatanabe)" - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchaudio.transforms as T - - -class GaussianFilter(nn.Module): - def __init__(self, radius, sigma=None): - super().__init__() - if sigma is None: - sigma = radius / 2 - self.radius = radius - self.register_buffer("kernel", self.gen_kernel_1d(radius, sigma=sigma)) - - @staticmethod - def gen_kernel_1d(radius, sigma=None): - if sigma is None: - sigma = radius / 2 - - kernel_size = 2 * radius + 1 - x = torch.arange(kernel_size).float() - radius - - kernel = torch.exp(-0.5 * (x / sigma) ** 2) - kernel = kernel / (sigma * math.sqrt(2 * math.pi)) - kernel = kernel / torch.sum(kernel) - - return kernel.unsqueeze(0).unsqueeze(0) - - def forward(self, x): - """x.shape: (batch_size, n_chs, seq_len)""" - - if x.ndim == 1: - x = x.unsqueeze(0).unsqueeze(0) - elif x.ndim == 2: - x = x.unsqueeze(1) - - channels = x.size(1) - kernel = self.kernel.expand(channels, 1, -1).to(x.device).to(x.dtype) - - return torch.nn.functional.conv1d( - x, kernel, padding=self.radius, groups=channels - ) diff --git a/src/scitex/nn/_Hilbert.py b/src/scitex/nn/_Hilbert.py deleted file mode 
100755 index 14209b46a..000000000 --- a/src/scitex/nn/_Hilbert.py +++ /dev/null @@ -1,111 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Timestamp: "2025-04-10 12:46:06 (ywatanabe)" -# File: /home/ywatanabe/proj/scitex_repo/src/scitex/nn/_Hilbert.py -# ---------------------------------------- -import os - -__FILE__ = "/home/ywatanabe/proj/scitex_repo/src/scitex/nn/_Hilbert.py" -__DIR__ = os.path.dirname(__FILE__) -# ---------------------------------------- -#!/usr/bin/env python - -import torch # 1.7.1 -import torch.nn as nn -from torch.fft import fft, ifft - - -class Hilbert(nn.Module): - def __init__(self, seq_len, dim=-1, fp16=False, in_place=False): - super().__init__() - self.dim = dim - self.fp16 = fp16 - self.in_place = in_place - self.n = seq_len - f = torch.cat( - [ - torch.arange(0, (self.n - 1) // 2 + 1) / float(self.n), - torch.arange(-(self.n // 2), 0) / float(self.n), - ] - ) - self.register_buffer("f", f) - - def hilbert_transform(self, x): - # n = x.shape[self.dim] - - # Create frequency dim - # f = torch.cat( - # [ - # torch.arange(0, (n - 1) // 2 + 1, device=x.device) / float(n), - # torch.arange(-(n // 2), 0, device=x.device) / float(n), - # ] - # ) - - orig_dtype = x.dtype - x = x.float() - xf = fft(x, n=self.n, dim=self.dim) - x = x.to(orig_dtype) - - # Create step function - steepness = 50 # This value can be adjusted - u = torch.sigmoid( - steepness * self.f.type_as(x) - ) # Soft step function for differentiability - - transformed = ifft(xf * 2 * u, dim=self.dim) - - return transformed - - def forward(self, x): - if self.fp16: - x = x.half() - - if not self.in_place: - x = x.clone() # Ensure that we do not modify the input in-place - - x_comp = self.hilbert_transform(x) - - pha = torch.atan2(x_comp.imag, x_comp.real) - amp = x_comp.abs() - - assert x.shape == pha.shape == amp.shape - - out = torch.cat( - [ - pha.unsqueeze(-1), - amp.unsqueeze(-1), - ], - dim=-1, - ) - - # if self.fp16: - # out = ( - # out.float() - # ) - # # Optionally cast back to float for stability in subsequent operations - - if self.fp16: - out = out.float() - - return out - - -if __name__ == "__main__": - import scitex - - xx, tt, fs = scitex.dsp.demo_sig() - xx = torch.tensor(xx) - - # Parameters - device = "cuda" - fp16 = True - in_place = True - - # Initialization - m = Hilbert(xx.shape[-1], fp16=fp16, in_place=in_place).to(device) - - # Calculation - xx = xx.to(device) - y = m(xx) - -# EOF diff --git a/src/scitex/nn/_MNet_1000.py b/src/scitex/nn/_MNet_1000.py deleted file mode 100755 index 096784603..000000000 --- a/src/scitex/nn/_MNet_1000.py +++ /dev/null @@ -1,158 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2023-05-04 16:54:55 (ywatanabe)" - -#!/usr/bin/env python - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchsummary import summary - -import scitex - -MNet_config = { - "classes": ["class1", "class2"], - "n_chs": 270, - "n_fc1": 1024, - "d_ratio1": 0.85, - "n_fc2": 256, - "d_ratio2": 0.85, -} - - -class MNet1000(nn.Module): - def __init__(self, config): - super().__init__() - - # basic - self.config = config - # fc - N_FC_IN = 15950 - - # conv - self.backborn = nn.Sequential( - *[ - nn.Conv2d(1, 40, kernel_size=(config["n_chs"], 4)), - nn.Mish(), - nn.Conv2d(40, 40, kernel_size=(1, 4)), - nn.BatchNorm2d(40), - nn.MaxPool2d((1, 5)), - nn.Mish(), - SwapLayer(), - nn.Conv2d(1, 50, kernel_size=(8, 12)), - nn.BatchNorm2d(50), - nn.MaxPool2d((3, 3)), - nn.Mish(), - nn.Conv2d(50, 50, kernel_size=(1, 5)), - 
nn.BatchNorm2d(50), - nn.MaxPool2d((1, 2)), - nn.Mish(), - ReshapeLayer(), - nn.Linear(N_FC_IN, config["n_fc1"]), - ] - ) - - # # conv - # self.conv1 = nn.Conv2d(1, 40, kernel_size=(config["n_chs"], 4)) - # self.act1 = nn.Mish() - - # self.conv2 = nn.Conv2d(40, 40, kernel_size=(1, 4)) - # self.bn2 = nn.BatchNorm2d(40) - # self.pool2 = nn.MaxPool2d((1, 5)) - # self.act2 = nn.Mish() - - # self.swap = SwapLayer() - - # self.conv3 = nn.Conv2d(1, 50, kernel_size=(8, 12)) - # self.bn3 = nn.BatchNorm2d(50) - # self.pool3 = nn.MaxPool2d((3, 3)) - # self.act3 = nn.Mish() - - # self.conv4 = nn.Conv2d(50, 50, kernel_size=(1, 5)) - # self.bn4 = nn.BatchNorm2d(50) - # self.pool4 = nn.MaxPool2d((1, 2)) - # self.act4 = nn.Mish() - - self.fc = nn.Sequential( - # nn.Linear(N_FC_IN, config["n_fc1"]), - nn.Mish(), - nn.Dropout(config["d_ratio1"]), - nn.Linear(config["n_fc1"], config["n_fc2"]), - nn.Mish(), - nn.Dropout(config["d_ratio2"]), - nn.Linear(config["n_fc2"], len(config["classes"])), - ) - - @staticmethod - def _reshape_input(x, n_chs): - """ - (batch, channel, time_length) -> (batch, channel, time_length, new_axis) - """ - if x.ndim == 3: - x = x.unsqueeze(-1) - if x.shape[2] == n_chs: - x = x.transpose(1, 2) - x = x.transpose(1, 3).transpose(2, 3) - return x - - @staticmethod - def _znorm_along_the_last_dim(x): - return (x - x.mean(dim=-1, keepdims=True)) / x.std(dim=-1, keepdims=True) - - def forward(self, x): - # # time-wise normalization - # x = self._znorm_along_the_last_dim(x) - # x = self._reshape_input(x, self.config["n_chs"]) - - # x = self.backborn(x) - x = self.forward_bb(x) - - # x = x.reshape(len(x), -1) - - x = self.fc(x) - - return x - - def forward_bb(self, x): - # time-wise normalization - x = self._znorm_along_the_last_dim(x) - x = self._reshape_input(x, self.config["n_chs"]) - x = self.backborn(x) - return x - - -class SwapLayer(nn.Module): - def __init__( - self, - ): - super().__init__() - - def forward(self, x): - return x.transpose(1, 2) - - -class ReshapeLayer(nn.Module): - def __init__( - self, - ): - super().__init__() - - def forward(self, x): - return x.reshape(len(x), -1) - - -if __name__ == "__main__": - ## Demo data - BS, N_CHS, SEQ_LEN = 16, 270, 1000 - x = torch.rand(BS, N_CHS, SEQ_LEN).cuda() - - ## Config for the model - model = MNet_1000(MNet_config).cuda() - - y = model(x) - summary(model, x) - print(y.shape) - -# Backward compatibility -MNet_1000 = MNet1000 # Deprecated: use MNet1000 instead diff --git a/src/scitex/nn/_ModulationIndex.py b/src/scitex/nn/_ModulationIndex.py deleted file mode 100755 index 1e4fc2275..000000000 --- a/src/scitex/nn/_ModulationIndex.py +++ /dev/null @@ -1,224 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-11-04 02:08:01 (ywatanabe)" -# File: ./scitex_repo/src/scitex/nn/_ModulationIndex.py - -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-10-15 14:12:55 (ywatanabe)" - -""" -This script defines the ModulationIndex module. 
-""" - -# Imports -import sys -import warnings - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - - -# Functions -class ModulationIndex(nn.Module): - def __init__(self, n_bins=18, fp16=False, amp_prob=False): - super(ModulationIndex, self).__init__() - self.n_bins = n_bins - self.fp16 = fp16 - self.register_buffer( - "pha_bin_cutoffs", torch.linspace(-np.pi, np.pi, n_bins + 1) - ) - - self.amp_prob = amp_prob - - @property - def pha_bin_centers( - self, - ): - return ( - ((self.pha_bin_cutoffs[1:] + self.pha_bin_cutoffs[:-1]) / 2) - .detach() - .cpu() - .numpy() - ) - - def forward(self, pha, amp, epsilon=1e-9): - """ - Compute the Modulation Index based on phase (pha) and amplitude (amp) tensors. - - Parameters: - - pha (torch.Tensor): Tensor of phase values with shape - (batch_size, n_channels, n_freqs_pha, n_segments, sequence_length). - - amp (torch.Tensor): Tensor of amplitude values with a similar shape as pha. - (batch_size, n_channels, n_freqs_amp, n_segments, sequence_length). - - Returns: - - MI (torch.Tensor): The Modulation Index for each batch and channel. - """ - assert pha.ndim == amp.ndim == 5 - - if self.fp16: - pha, amp = pha.half().contiguous(), amp.half().contiguous() - else: - pha, amp = pha.float().contiguous(), amp.float().contiguous() - - device = pha.device - - pha_masks = self._phase_to_masks(pha, self.pha_bin_cutoffs.to(device)) - # (batch_size, n_channels, n_freqs_pha, n_segments, sequence_length, n_bins) - - # Expands amp and masks to utilize broadcasting - # i_batch = 0 - # i_chs = 1 - i_freqs_pha = 2 - i_freqs_amp = 3 - # i_segments = 4 - i_time = 5 - i_bins = 6 - - # Coupling - pha_masks = pha_masks.unsqueeze(i_freqs_amp) - amp = amp.unsqueeze(i_freqs_pha).unsqueeze(i_bins) - - amp_bins = pha_masks * amp # this is the most memory-consuming process - - # # Batch processing to reduce maximum VRAM occupancy - # pha_masks = self.dh_pha.fit(pha_masks, keepdims=[2, 3, 5, 6]) - # amp = self.dh_amp.fit(amp, keepdims=[2, 3, 5, 6]) - # n_chunks = len(pha_masks) // self.chunk_size - # amp_bins = [] - # for i_chunk in range(n_chunks): - # start = i_chunk * self.chunk_size - # end = (i_chunk + 1) * self.chunk_size - # _amp_bins = pha_masks[start:end] * amp[start:end] - # amp_bins.append(_amp_bins.cpu()) - # amp_bins = torch.cat(amp_bins) - # amp_bins = self.dh_pha.unfit(amp_bins) - # pha_masks = self.dh_pha.unfit(pha_masks) - # Takes mean amplitude in each bin - amp_sums = amp_bins.sum(dim=i_time, keepdims=True).to(device) - counts = pha_masks.sum(dim=i_time, keepdims=True) - amp_means = amp_sums / (counts + epsilon) - - amp_probs = amp_means / (amp_means.sum(dim=-1, keepdims=True) + epsilon) - - if self.amp_prob: - return amp_probs.detach().cpu() - - """ - matplotlib.use("TkAgg") - fig, ax = scitex.plt.subplots(subplot_kw={'polar': True}) - yy = amp_probs[0, 0, 0, 0, 0, 0, :].detach().cpu().numpy() - xx = ((self.pha_bin_cutoffs[1:] + self.pha_bin_cutoffs[:-1]) / 2).detach().cpu().numpy() - ax.bar(xx, yy, width=.1) - plt.show() - """ - - MI = ( - torch.log(torch.tensor(self.n_bins, device=device) + epsilon) - + (amp_probs * (amp_probs + epsilon).log()).sum(dim=-1) - ) / torch.log(torch.tensor(self.n_bins, device=device)) - - # Squeeze the n_bin dimension - MI = MI.squeeze(-1) - - # Takes mean along the n_segments dimension - i_segment = -1 - MI = MI.mean(axis=i_segment) - - if MI.isnan().any(): - warnings.warn("NaN values detected in Modulation Index calculation.") - # raise ValueError( - # "NaN values detected in Modulation 
Index calculation." - # ) - - return MI - - @staticmethod - def _phase_to_masks(pha, phase_bin_cutoffs): - n_bins = int(len(phase_bin_cutoffs) - 1) - bin_indices = ( - (torch.bucketize(pha, phase_bin_cutoffs, right=False) - 1).clamp( - 0, n_bins - 1 - ) - ).long() - one_hot_masks = ( - F.one_hot( - bin_indices, - num_classes=n_bins, - ) - .bool() - .to(pha.device) - ) - return one_hot_masks - - -def _reshape(x, batch_size=2, n_chs=4): - return ( - torch.tensor(x) - .float() - .unsqueeze(0) - .unsqueeze(0) - .repeat(batch_size, n_chs, 1, 1, 1) - ) - - -if __name__ == "__main__": - import matplotlib.pyplot as plt - - import scitex - - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start( - sys, plt, fig_scale=3 - ) - - # Parameters - FS = 512 - T_SEC = 1 - device = "cuda" - - # Demo signal - xx, tt, fs = scitex.dsp.demo_sig(fs=FS, t_sec=T_SEC, sig_type="tensorpac") - # xx.shape: (8, 19, 20, 512) - - # Tensorpac - ( - pha, - amp, - freqs_pha, - freqs_amp, - pac_tp, - ) = scitex.dsp.utils.pac.calc_pac_with_tensorpac(xx, fs, t_sec=T_SEC) - - # GPU calculation with scitex.dsp.nn.ModulationIndex - pha, amp = _reshape(pha), _reshape(amp) - - m = ModulationIndex(n_bins=18, fp16=True).to(device) - - pac_scitex = m(pha.to(device), amp.to(device)) - - # pac_scitex = scitex.dsp.modulation_index(pha, amp).cpu().numpy() - i_batch, i_ch = 0, 0 - pac_scitex = pac_scitex[i_batch, i_ch].squeeze().numpy() - - # Plots - fig = scitex.dsp.utils.pac.plot_PAC_scitex_vs_tensorpac( - pac_scitex, pac_tp, freqs_pha, freqs_amp - ) - # fig = plot_PAC_scitex_vs_tensorpac(pac_scitex, pac_tp, freqs_pha, freqs_amp) - scitex.io.save(fig, CONFIG["SDIR"] + "modulation_index.png") # plt.show() - - # Close - scitex.session.close(CONFIG) - -# EOF - -""" -/home/ywatanabe/proj/entrance/scitex/nn/_ModulationIndex.py -""" - - -# EOF diff --git a/src/scitex/nn/_PAC.py b/src/scitex/nn/_PAC.py deleted file mode 100755 index 73df47aa3..000000000 --- a/src/scitex/nn/_PAC.py +++ /dev/null @@ -1,416 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-11-26 10:33:30 (ywatanabe)" -# File: ./scitex_repo/src/scitex/nn/_PAC.py - -THIS_FILE = "/home/ywatanabe/proj/scitex_repo/src/scitex/nn/_PAC.py" - -# Imports -import sys -import warnings - -import matplotlib.pyplot as plt -import torch -import torch.nn as nn - -import scitex - - -# Functions -class PAC(nn.Module): - def __init__( - self, - seq_len, - fs, - pha_start_hz=2, - pha_end_hz=20, - pha_n_bands=50, - amp_start_hz=60, - amp_end_hz=160, - amp_n_bands=30, - n_perm=None, - trainable=False, - in_place=True, - fp16=False, - amp_prob=False, - ): - super().__init__() - - self.fp16 = fp16 - self.n_perm = n_perm - self.amp_prob = amp_prob - self.trainable = trainable - - if n_perm is not None: - if not isinstance(n_perm, int): - raise ValueError("n_perm should be None or an integer.") - - # caps amp_end_hz - factor = 0.8 - amp_end_hz = int(min(fs / 2 / (1 + factor) - 1, amp_end_hz)) - - self.bandpass = self.init_bandpass( - seq_len, - fs, - pha_start_hz=pha_start_hz, - pha_end_hz=pha_end_hz, - pha_n_bands=pha_n_bands, - amp_start_hz=amp_start_hz, - amp_end_hz=amp_end_hz, - amp_n_bands=amp_n_bands, - fp16=fp16, - trainable=trainable, - ) - - self.hilbert = scitex.nn.Hilbert(seq_len, dim=-1, fp16=fp16) - - self.Modulation_index = scitex.nn.ModulationIndex( - n_bins=18, - fp16=fp16, - amp_prob=amp_prob, - ) - - # Data Handlers - self.dh_pha = scitex.gen.DimHandler() - self.dh_amp = scitex.gen.DimHandler() - - def forward(self, x): - """x.shape: 
(batch_size, n_chs, seq_len) or (batch_size, n_chs, n_segments, seq_len)""" - - with torch.set_grad_enabled(bool(self.trainable)): - x = self._ensure_4d_input(x) - # (batch_size, n_chs, n_segments, seq_len) - - batch_size, n_chs, n_segments, seq_len = x.shape - - x = x.reshape(batch_size * n_chs, n_segments, seq_len) - # (batch_size * n_chs, n_segments, seq_len) - - x = self.bandpass(x, edge_len=0) - # (batch_size*n_chs, n_segments, n_pha_bands + n_amp_bands, seq_len) - - x = self.hilbert(x) - # (batch_size*n_chs, n_segments, n_pha_bands + n_amp_bands, seq_len, pha + amp) - - x = x.reshape(batch_size, n_chs, *x.shape[1:]) - # (batch_size, n_chs, n_segments, n_pha_bands + n_amp_bands, seq_len, pha + amp) - - x = x.transpose(2, 3) - # (batch_size, n_chs, n_pha_bands + n_amp_bands, n_segments, pha + amp) - - if self.fp16: - x = x.half() - - pha = x[:, :, : len(self.PHA_MIDS_HZ), :, :, 0] - # (batch_size, n_chs, n_freqs_pha, n_segments, sequence_length) - - amp = x[:, :, -len(self.AMP_MIDS_HZ) :, :, :, 1] - # (batch_size, n_chs, n_freqs_amp, n_segments, sequence_length)() - - edge_len = int(pha.shape[-1] // 8) - - pha = pha[..., edge_len:-edge_len].half() - amp = amp[..., edge_len:-edge_len].half() - - pac_or_amp_prob = self.Modulation_index(pha, amp) # .squeeze() - # print(pac_or_amp_prob.shape) - # pac_or_amp_prob = pac_or_amp_prob.squeeze() - - if self.n_perm is None: - return pac_or_amp_prob - else: - return self.to_z_using_surrogate(pha, amp, pac_or_amp_prob) - - def to_z_using_surrogate(self, pha, amp, observed): - surrogates = self.generate_surrogates(pha, amp) - mm = surrogates.mean(dim=2).to(observed.device) - ss = surrogates.std(dim=2).to(observed.device) - return (observed - mm) / (ss + 1e-5) - - # if self.amp_prob: - # amp_prob = self.Modulation_index(pha, amp).squeeze() - # amp_prob.shape # torch.Size([2, 8, 50, 50, 3, 18]) - # pac_surrogates = self.generate_surrogates(pha, amp) - # # torch.Size([2, 8, 3, 50, 50, 3, 18]) - # __import__("ipdb").set_trace() - # return amp_prob - - # elif not self.amp_prob: - # pac = self.Modulation_index(pha, amp).squeeze() # torch.Size([2, 8, 50, 50]) - - # if self.n_perm is not None: - # pac_surrogates = self.generate_surrogates(pha, amp) - # # torch.Size([2, 8, 3, 50, 50]) # self.amp_prob = False - # __import__("ipdb").set_trace() - # mm = pac_surrogates.mean(dim=2).to(pac.device) - # ss = pac_surrogates.std(dim=2).to(pac.device) - # pac_z = (pac - mm) / (ss + 1e-5) - # return pac_z - - # return pac - - def generate_surrogates(self, pha, amp, bs=1): - # Shape of pha: [batch_size, n_chs, n_freqs_pha, n_segments, sequence_length] - batch_size, n_chs, n_freqs_pha, n_segments, seq_len = pha.shape - _, _, n_freqs_amp, _, _ = amp.shape - - # cut and shuffle - cut_points = torch.randint(seq_len, (self.n_perm,), device=pha.device) - ranges = torch.arange(seq_len, device=pha.device) - indices = cut_points.unsqueeze(0) - ranges.unsqueeze(1) - - pha = pha[..., indices] - amp = amp.unsqueeze(-1).expand(-1, -1, -1, -1, -1, self.n_perm) - - pha = self.dh_pha.fit(pha, keepdims=[2, 3, 4]) - amp = self.dh_amp.fit(amp, keepdims=[2, 3, 4]) - - if self.fp16: - pha = pha.half() - amp = amp.half() - - # print("\nCalculating surrogate PAC values...") - - surrogate_pacs = [] - n_batches = (len(pha) + bs - 1) // bs - device = "cuda" - with torch.no_grad(): - # ######################################## - # # fixme - # pha = pha.to(device) - # amp = amp.to(device) - # ######################################## - - for i_batch in range(n_batches): - start = i_batch * bs - end 
= min((i_batch + 1) * bs, pha.shape[0]) - - _pha = pha[start:end].unsqueeze(1).to(device) # n_chs = 1 - _amp = amp[start:end].unsqueeze(1).to(device) # n_chs = 1 - - _surrogate_pacs = self.Modulation_index(_pha, _amp).cpu() - surrogate_pacs.append(_surrogate_pacs) - - # # Optionally clear cache if memory is an issue - # torch.cuda.empty_cache() - - torch.cuda.empty_cache() - surrogate_pacs = torch.vstack(surrogate_pacs).squeeze() - surrogate_pacs = self.dh_pha.unfit(surrogate_pacs) - - return surrogate_pacs - - def init_bandpass( - self, - seq_len, - fs, - pha_start_hz=2, - pha_end_hz=20, - pha_n_bands=50, - amp_start_hz=60, - amp_end_hz=160, - amp_n_bands=30, - trainable=False, - fp16=False, - ): - # A static, gen purpose BandPassFilter - if not trainable: - # First, bands definitions for phase and amplitude are declared - self.BANDS_PHA = self.calc_bands_pha( - start_hz=pha_start_hz, - end_hz=pha_end_hz, - n_bands=pha_n_bands, - ) - self.BANDS_AMP = self.calc_bands_amp( - start_hz=amp_start_hz, - end_hz=amp_end_hz, - n_bands=amp_n_bands, - ) - bands_all = torch.vstack([self.BANDS_PHA, self.BANDS_AMP]) - - # Instanciation of the static bandpass filter module - self.bandpass = scitex.nn.BandPassFilter( - bands_all, - fs, - seq_len, - fp16=fp16, - ) - self.PHA_MIDS_HZ = self.BANDS_PHA.mean(-1) - self.AMP_MIDS_HZ = self.BANDS_AMP.mean(-1) - - # A trainable BandPassFilter specifically for PAC calculation. Bands will be optimized. - elif trainable: - self.bandpass = scitex.nn.DifferentiableBandPassFilter( - seq_len, - fs, - fp16=fp16, - pha_low_hz=pha_start_hz, - pha_high_hz=pha_end_hz, - pha_n_bands=pha_n_bands, - amp_low_hz=amp_start_hz, - amp_high_hz=amp_end_hz, - amp_n_bands=amp_n_bands, - ) - self.PHA_MIDS_HZ = self.bandpass.pha_mids - self.AMP_MIDS_HZ = self.bandpass.amp_mids - - return self.bandpass - - @staticmethod - def calc_bands_pha(start_hz=2, end_hz=20, n_bands=100): - start_hz = start_hz if start_hz is not None else 2 - end_hz = end_hz if end_hz is not None else 20 - mid_hz = torch.linspace(start_hz, end_hz, n_bands) - return torch.cat( - ( - mid_hz.unsqueeze(1) - mid_hz.unsqueeze(1) / 4.0, - mid_hz.unsqueeze(1) + mid_hz.unsqueeze(1) / 4.0, - ), - dim=1, - ) - - @staticmethod - def calc_bands_amp(start_hz=30, end_hz=160, n_bands=100): - start_hz = start_hz if start_hz is not None else 30 - end_hz = end_hz if end_hz is not None else 160 - mid_hz = torch.linspace(start_hz, end_hz, n_bands) - return torch.cat( - ( - mid_hz.unsqueeze(1) - mid_hz.unsqueeze(1) / 8.0, - mid_hz.unsqueeze(1) + mid_hz.unsqueeze(1) / 8.0, - ), - dim=1, - ) - - @staticmethod - def _ensure_4d_input(x): - if x.ndim != 4: - message = f"Input tensor must be 4D with the shape (batch_size, n_chs, n_segments, seq_len). 
Received shape: {x.shape}" - - if x.ndim == 3: - # warnings.warn( - # "'n_segments' was determined to be 1, assuming your input is (batch_size, n_chs, seq_len).", - # UserWarning, - # ) - x = x.unsqueeze(-2) - - if x.ndim != 4: - raise ValueError(message) - - return x - - -if __name__ == "__main__": - # Start - CONFIG, sys.stdout, sys.stderr, plt, CC = scitex.session.start(sys, plt) - - ts = scitex.gen.TimeStamper() - - # Parameters - FS = 512 - T_SEC = 8 - PLOT = False - fp16 = True - trainable = False - n_perm = 3 - in_place = True - amp_prob = True - - # Demo Signal - xx, tt, fs = scitex.dsp.demo_sig( - batch_size=2, - n_chs=8, - n_segments=3, - fs=FS, - t_sec=T_SEC, - sig_type="tensorpac", - # sig_type="pac", - ) - xx = torch.tensor(xx).cuda() - xx.requires_grad = False - # (2, 8, 2, 4096) - - # PAC object initialization - ts("PAC initialization starts") - m = PAC( - xx.shape[-1], - fs, - pha_start_hz=2, - pha_end_hz=20, - pha_n_bands=50, - amp_start_hz=60, - amp_end_hz=160, - amp_n_bands=50, - fp16=fp16, - trainable=trainable, - n_perm=n_perm, - in_place=in_place, - amp_prob=amp_prob, - ).cuda() - ts("PAC initialization ends") - - # PAC calculation - ts("PAC calculation starts") - pac = m(xx) - ts("PAC calculation ends") - - """ - amp_prob = m(xx) - amp_prob = amp_prob.reshape(-1, amp_prob.shape[-1]) - xx = m.Modulation_index.pha_bin_centers - plt.bar(xx, amp_prob[0]) - """ - - scitex.gen.print_block( - f"PAC calculation time: {ts.delta(-1, -2):.3f} sec", c="yellow" - ) - # 0.17 sec - scitex.gen.print_block( - f"x.shape: {xx.shape}" - f"\nfp16: {fp16}" - f"\ntrainable: {trainable}" - f"\nn_perm: {n_perm}" - f"\nin_place: {in_place}" - ) - - # # Plots - # if PLOT: - # pac = pac.detach().cpu().numpy() - # fig, ax = scitex.plt.subplots() - # ax.imshow2d(pac[0, 0], cbar_label="PAC value [zscore]") - # ax.set_ticks( - # x_vals=m.PHA_MIDS_HZ, - # x_ticks=np.linspace(m.PHA_MIDS_HZ[0], m.PHA_MIDS_HZ[-1], 4), - # y_vals=m.AMP_MIDS_HZ, - # y_ticks=np.linspace(m.AMP_MIDS_HZ[0], m.AMP_MIDS_HZ[-1], 4), - # ) - # ax.set_xyt( - # "Frequency for phase [Hz]", - # "Amplitude for phase [Hz]", - # "PAC values", - # ) - # plt.show() - - -# EOF - -""" -/home/ywatanabe/proj/entrance/scitex/dsp/nn/_PAC.py -""" - -# # close -# fig, axes = scitex.plt.subplots(ncols=2) -# axes[0].imshow2d(pac_scitex[i_batch, i_ch]) -# axes[1].imshow2d(pac_tp) -# scitex.io.save(fig, CONFIG["SDIR"] + "pac.png") -# import numpy as np -# np.corrcoef(pac_scitex[i_batch, i_ch], pac_tp)[0, 1] -# import matplotlib - -# plt.close("all") -# matplotlib.use("TkAgg") -# plt.scatter(pac_scitex[i_batch, i_ch].reshape(-1), pac_tp.reshape(-1)) -# plt.show() - - -# EOF diff --git a/src/scitex/nn/_PSD.py b/src/scitex/nn/_PSD.py deleted file mode 100755 index ed2554737..000000000 --- a/src/scitex/nn/_PSD.py +++ /dev/null @@ -1,39 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-04-11 21:50:09 (ywatanabe)" - -import torch -import torch.nn as nn - - -class PSD(nn.Module): - def __init__(self, sample_rate, prob=False, dim=-1): - super(PSD, self).__init__() - self.sample_rate = sample_rate - self.dim = dim - self.prob = prob - - def forward(self, signal): - is_complex = signal.is_complex() - if is_complex: - signal_fft = torch.fft.fft(signal, dim=self.dim) - freqs = torch.fft.fftfreq(signal.size(self.dim), 1 / self.sample_rate).to( - signal.device - ) - - else: - signal_fft = torch.fft.rfft(signal, dim=self.dim) - freqs = torch.fft.rfftfreq(signal.size(self.dim), 1 / self.sample_rate).to( - signal.device - ) - - 
power_spectrum = torch.abs(signal_fft) ** 2 - power_spectrum = power_spectrum / signal.size(self.dim) - - psd = power_spectrum * (1.0 / self.sample_rate) - - # To probability if specified - if self.prob: - psd /= psd.sum(dim=self.dim, keepdims=True) - - return psd, freqs diff --git a/src/scitex/nn/_ResNet1D.py b/src/scitex/nn/_ResNet1D.py deleted file mode 100755 index 0e86c4be7..000000000 --- a/src/scitex/nn/_ResNet1D.py +++ /dev/null @@ -1,120 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2023-05-15 16:46:54 (ywatanabe)" - -import torch -import torch.nn as nn -from torchsummary import summary - - -class ResNet1D(nn.Module): - """ - A representative convolutional neural network for signal classification tasks. - """ - - def __init__(self, n_chs=19, n_out=10, n_blks=5): - super().__init__() - - # Parameters - N_CHS = n_chs - _N_FILTS_PER_CH = 4 - N_FILTS = N_CHS * _N_FILTS_PER_CH - N_BLKS = n_blks - - # Convolutional layers - self.res_conv_blk_layers = nn.Sequential( - ResNetBasicBlock(N_CHS, N_FILTS), - *[ResNetBasicBlock(N_FILTS, N_FILTS) for _ in range(N_BLKS - 1)], - ) - - # ## FC layer - # self.fc = nn.Sequential( - # nn.Linear(N_FILTS, 64), - # nn.ReLU(), - # nn.Dropout(p=0.5), - # nn.Linear(64, 32), - # nn.ReLU(), - # nn.Dropout(p=0.5), - # nn.Linear(32, n_out), - # ) - - def forward(self, x): - x = self.res_conv_blk_layers(x) - # x = x.mean(axis=-1) - # x = self.fc(x) - return x - - -class ResNetBasicBlock(nn.Module): - """The basic block of the ResNet1D model""" - - def __init__(self, in_chs, out_chs): - super(ResNetBasicBlock, self).__init__() - self.in_chs = in_chs - self.out_chs = out_chs - - self.conv7 = self.conv_k(in_chs, out_chs, k=7, p=3) - self.bn7 = nn.BatchNorm1d(out_chs) - self.activation7 = nn.ReLU() - - self.conv5 = self.conv_k(out_chs, out_chs, k=5, p=2) - self.bn5 = nn.BatchNorm1d(out_chs) - self.activation5 = nn.ReLU() - - self.conv3 = self.conv_k(out_chs, out_chs, k=3, p=1) - self.bn3 = nn.BatchNorm1d(out_chs) - self.activation3 = nn.ReLU() - - self.expansion_conv = self.conv_k(in_chs, out_chs, k=1, p=0) - - self.bn = nn.BatchNorm1d(out_chs) - self.activation = nn.ReLU() - - @staticmethod - def conv_k(in_chs, out_chs, k=1, s=1, p=1): - """Build size k kernel's convolution layer with padding""" - return nn.Conv1d( - in_chs, out_chs, kernel_size=k, stride=s, padding=p, bias=False - ) - - def forward(self, x): - residual = x - - x = self.conv7(x) - x = self.bn7(x) - x = self.activation7(x) - - x = self.conv5(x) - x = self.bn5(x) - x = self.activation5(x) - - x = self.conv3(x) - x = self.bn3(x) - x = self.activation3(x) - - if self.in_chs != self.out_chs: - residual = self.expansion_conv(residual) - residual = self.bn(residual) - - x = x + residual - x = self.activation(x) - - return x - - -if __name__ == "__main__": - import sys - - sys.path.append("./DEAP/") - import utils - - # Demo data - bs, n_chs, seq_len = 16, 32, 8064 - Xb = torch.rand(bs, n_chs, seq_len) - - model = ResNet1D( - n_chs=n_chs, - n_out=4, - ) # utils.load_yaml("./config/global.yaml")["EMOTIONS"] - y = model(Xb) # 16,4 - summary(model, Xb) diff --git a/src/scitex/nn/_SpatialAttention.py b/src/scitex/nn/_SpatialAttention.py deleted file mode 100755 index 4da131081..000000000 --- a/src/scitex/nn/_SpatialAttention.py +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2023-04-23 09:45:28 (ywatanabe)" - -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torchsummary import summary - 
-import scitex - - -class SpatialAttention(nn.Module): - def __init__(self, n_chs_in): - super().__init__() - self.aap = nn.AdaptiveAvgPool1d(1) - self.conv11 = nn.Conv1d(in_channels=n_chs_in, out_channels=1, kernel_size=1) - - def forward(self, x): - """x: [batch_size, n_chs, seq_len]""" - x_orig = x - x = self.aap(x) - x = self.conv11(x) - - return x * x_orig diff --git a/src/scitex/nn/_Spectrogram.py b/src/scitex/nn/_Spectrogram.py deleted file mode 100755 index ddc734042..000000000 --- a/src/scitex/nn/_Spectrogram.py +++ /dev/null @@ -1,170 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -# Time-stamp: "2024-04-02 09:21:12 (ywatanabe)" - -import matplotlib.pyplot as plt -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F - -import scitex -from scitex.decorators import numpy_fn, torch_fn - - -class Spectrogram(nn.Module): - def __init__( - self, - sampling_rate, - n_fft=256, - hop_length=None, - win_length=None, - window="hann", - ): - super().__init__() - self.sampling_rate = sampling_rate - self.n_fft = n_fft - self.hop_length = hop_length if hop_length is not None else n_fft // 4 - self.win_length = win_length if win_length is not None else n_fft - if window == "hann": - self.window = torch.hann_window(window_length=self.win_length) - else: - raise ValueError( - "Unsupported window type. Extend this to support more window types." - ) - - def forward(self, x): - """ - Computes the spectrogram for each channel in the input signal. - - Parameters: - - signal (torch.Tensor): Input signal of shape (batch_size, n_chs, seq_len). - - Returns: - - spectrograms (torch.Tensor): The computed spectrograms for each channel. - """ - - x = scitex.dsp.ensure_3d(x) - - batch_size, n_chs, seq_len = x.shape - spectrograms = [] - - for ch in range(n_chs): - x_ch = x[:, ch, :].unsqueeze(1) # Maintain expected input shape for stft - spec = torch.stft( - x_ch.squeeze(1), - n_fft=self.n_fft, - hop_length=self.hop_length, - win_length=self.win_length, - window=self.window.to(x.device), - center=True, - pad_mode="reflect", - normalized=False, - return_complex=True, - ) - magnitude = torch.abs(spec).unsqueeze(1) # Keep channel dimension - spectrograms.append(magnitude) - - # Concatenate spectrograms along channel dimension - spectrograms = torch.cat(spectrograms, dim=1) - - # Calculate frequencies (y-axis) - freqs = torch.linspace(0, self.sampling_rate / 2, steps=self.n_fft // 2 + 1) - - # Calculate times (x-axis) - # The number of frames can be computed from the size of the last dimension of the spectrogram - n_frames = spectrograms.shape[-1] - # Time of each frame in seconds, considering the hop length and sampling rate - times_sec = torch.arange(0, n_frames) * (self.hop_length / self.sampling_rate) - - return spectrograms, freqs, times_sec - - -@torch_fn -def spectrograms(x, fs, cuda=False): - return Spectrogram(fs)(x) - - -@torch_fn -def my_softmax(x, dim=-1): - return F.softmax(x, dim=dim) - - -@torch_fn -def unbias(x, func="min", dim=-1, cuda=False): - if func == "min": - return x - x.min(dim=dim, keepdims=True)[0] - if func == "mean": - return x - x.mean(dim=dim, keepdims=True)[0] - - -@torch_fn -def normalize(x, axis=-1, amp=1.0, cuda=False): - high = torch.abs(x.max(axis=axis, keepdims=True)[0]) - low = torch.abs(x.min(axis=axis, keepdims=True)[0]) - return amp * x / torch.maximum(high, low) - - -@torch_fn -def spectrograms(x, fs, dj=0.125, cuda=False): - try: - from wavelets_pytorch.transform import WaveletTransformTorch # PyTorch version - except 
ImportError:
-        raise ImportError(
-            "The spectrograms function requires the wavelets-pytorch package. "
-            "Install it with: pip install wavelets-pytorch"
-        )
-
-    dt = 1 / fs
-    # dj = 0.125
-    batch_size, n_chs, seq_len = x.shape
-
-    x = x.cpu().numpy()
-
-    # # Batch of signals to process
-    # batch = np.array([batch_size * seq_len])
-
-    # Initialize wavelet filter banks (scipy and torch implementation)
-    # wa_scipy = WaveletTransform(dt, dj)
-    wa_torch = WaveletTransformTorch(dt, dj, cuda=True)
-
-    # Perform the wavelet transform (and compute the scalogram)
-    # cwt_scipy = wa_scipy.cwt(batch)
-    x = x[:, 0][:, np.newaxis]
-    cwt_torch = wa_torch.cwt(x)
-
-    return cwt_torch
-
-
-if __name__ == "__main__":
-    import seaborn as sns
-    import torchaudio
-
-    import scitex
-
-    fs = 1024  # 128
-    t_sec = 10
-    x = scitex.dsp.np.demo_sig(t_sec=t_sec, fs=fs, type="ripple")
-
-    normalize(unbias(x, cuda=True), cuda=True)
-
-    # My implementation
-    ss = spectrograms(x, fs, cuda=True)
-    fig, axes = plt.subplots(nrows=2)
-    axes[0].plot(np.arange(len(x[0, 0])) / fs, x[0, 0])
-    sns.heatmap(ss[0], ax=axes[1])
-    plt.show()
-
-    ss, ff, tt = spectrograms(x, fs, cuda=True)
-    fig, axes = plt.subplots(nrows=2)
-    axes[0].plot(np.arange(len(x[0, 0])) / fs, x[0, 0])
-    sns.heatmap(ss[0], ax=axes[1])
-    plt.show()
-
-    # Torch Audio
-    transform = torchaudio.transforms.Spectrogram(n_fft=16, normalized=True).cuda()
-    xx = torch.tensor(x).float().cuda()[0, 0]
-    ss = transform(xx)
-    sns.heatmap(ss.detach().cpu().numpy())
-
-    plt.show()
diff --git a/src/scitex/nn/_SwapChannels.py b/src/scitex/nn/_SwapChannels.py
deleted file mode 100755
index 7ba08cbdf..000000000
--- a/src/scitex/nn/_SwapChannels.py
+++ /dev/null
@@ -1,52 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2023-05-04 21:21:19 (ywatanabe)"
-
-import random
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torchsummary import summary
-
-import scitex
-
-
-class SwapChannels(nn.Module):
-    def __init__(self, dropout=0.5):
-        super().__init__()
-        self.dropout = nn.Dropout(p=dropout)
-
-    def forward(self, x):
-        """x: [batch_size, n_chs, seq_len]"""
-        if self.training:
-            orig_chs = torch.arange(x.shape[1])
-
-            indi_orig = self.dropout(torch.ones(x.shape[1])).bool()
-            chs_to_shuffle = orig_chs[~indi_orig]
-
-            rand_chs = random.sample(
-                list(np.array(chs_to_shuffle)), len(chs_to_shuffle)
-            )
-
-            swapped_chs = orig_chs.clone()
-            swapped_chs[~indi_orig] = torch.LongTensor(rand_chs)
-
-            x = x[:, swapped_chs.long(), :]
-
-        return x
-
-
-if __name__ == "__main__":
-    ## Demo data
-    bs, n_chs, seq_len = 16, 360, 1000
-    x = torch.rand(bs, n_chs, seq_len)
-
-    sc = SwapChannels()
-    print(sc(x).shape)  # [16, 360, 1000]
-
-    # sb = SubjectBlock(n_chs=n_chs)
-    # print(sb(x, s).shape)  # [16, 270, 1000]
-
-    # summary(sb, x, s)
diff --git a/src/scitex/nn/_TransposeLayer.py b/src/scitex/nn/_TransposeLayer.py
deleted file mode 100755
index 4b5eea607..000000000
--- a/src/scitex/nn/_TransposeLayer.py
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-03-30 07:26:35 (ywatanabe)"
-
-import torch.nn as nn
-
-
-class TransposeLayer(nn.Module):
-    def __init__(
-        self,
-        axis1,
-        axis2,
-    ):
-        super().__init__()
-        self.axis1 = axis1
-        self.axis2 = axis2
-
-    def forward(self, x):
-        return x.transpose(self.axis1, self.axis2)
diff --git a/src/scitex/nn/_Wavelet.py b/src/scitex/nn/_Wavelet.py
deleted file mode 100755
index 7338dd569..000000000
--- a/src/scitex/nn/_Wavelet.py
+++ /dev/null
@@ -1,185 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-11-03 07:17:26 (ywatanabe)"
-# File: ./scitex_repo/src/scitex/nn/_Wavelet.py
-
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-# Time-stamp: "2024-05-30 11:04:45 (ywatanabe)"
-
-
-import numpy as np
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-import scitex
-from scitex.gen._to_even import to_even
-from scitex.gen._to_odd import to_odd
-
-
-class Wavelet(nn.Module):
-    def __init__(
-        self, samp_rate, kernel_size=None, freq_scale="linear", out_scale="log"
-    ):
-        super().__init__()
-        self.register_buffer("dummy", torch.tensor(0))
-        self.kernel = None
-        self.init_kernel(samp_rate, kernel_size=kernel_size, freq_scale=freq_scale)
-        self.out_scale = out_scale
-
-    def forward(self, x):
-        """Apply the 2D filter (n_filts, kernel_size) to input signal x with shape: (batch_size, n_chs, seq_len)"""
-        x = scitex.dsp.ensure_3d(x).to(self.dummy.device)
-        seq_len = x.shape[-1]
-
-        # Ensure the kernel is initialized
-        if self.kernel is None:
-            self.init_kernel()
-        if self.kernel is None:
-            raise ValueError("Filter kernel has not been initialized.")
-        assert self.kernel.ndim == 2
-        self.kernel = self.kernel.to(x.device)  # cuda, torch.complex128
-
-        # Edge handling and convolution
-        extension_length = self.radius
-        first_segment = x[:, :, :extension_length].flip(dims=[-1])
-        last_segment = x[:, :, -extension_length:].flip(dims=[-1])
-        extended_x = torch.cat([first_segment, x, last_segment], dim=-1)
-
-        # Add a singleton input-channel dim to the kernels and flatten (batch, ch)
-        # so each signal is convolved with every wavelet kernel
-        kernel_batched = self.kernel.unsqueeze(1)
-        extended_x_reshaped = extended_x.view(-1, 1, extended_x.shape[-1])
-
-        filtered_x_real = F.conv1d(
-            extended_x_reshaped, kernel_batched.real.float(), groups=1
-        )
-        filtered_x_imag = F.conv1d(
-            extended_x_reshaped, kernel_batched.imag.float(), groups=1
-        )
-
-        filtered_x = torch.view_as_complex(
-            torch.stack([filtered_x_real, filtered_x_imag], dim=-1)
-        )
-
-        filtered_x = filtered_x.view(
-            x.shape[0], x.shape[1], kernel_batched.shape[0], -1
-        )
-        filtered_x = filtered_x[..., :seq_len]
-        assert filtered_x.shape[-1] == seq_len
-
-        pha = filtered_x.angle()
-        amp = filtered_x.abs()
-
-        # Repeat freqs across the batch and channel dimensions
-        freqs = (
-            self.freqs.unsqueeze(0).unsqueeze(0).repeat(pha.shape[0], pha.shape[1], 1)
-        )
-
-        if self.out_scale == "log":
-            return pha, torch.log(amp + 1e-5), freqs
-        else:
-            return pha, amp, freqs
-
-    def init_kernel(self, samp_rate, kernel_size=None, freq_scale="log"):
-        device = self.dummy.device
-        morlets, freqs = self.gen_morlet_to_nyquist(
-            samp_rate, kernel_size=kernel_size, freq_scale=freq_scale
-        )
-        self.kernel = torch.tensor(morlets).to(device)
-        self.freqs = torch.tensor(freqs).float().to(device)
-
-    @staticmethod
-    def gen_morlet_to_nyquist(samp_rate, kernel_size=None, freq_scale="linear"):
-        """
-        Generates Morlet wavelets for increasing frequency bands (linear or log spacing) up to the Nyquist frequency.
-
-        Parameters:
-        - samp_rate (int): The sampling rate of the signal, in Hertz.
-        - kernel_size (int): The size of the kernel, in number of samples.
-
-        Returns:
-        - np.ndarray: A 2D array of complex values representing the Morlet wavelets for each frequency band.
- """ - if kernel_size is None: - kernel_size = int(samp_rate) # * 2.5) - - nyquist_freq = samp_rate / 2 - - # Log freq_scale - def calc_freq_boundaries_log(nyquist_freq): - n_kernels = int(np.floor(np.log2(nyquist_freq))) - mid_hz = np.array([2 ** (n + 1) for n in range(n_kernels)]) - width_hz = np.hstack([np.array([1]), np.diff(mid_hz) / 2]) + 1 - low_hz = mid_hz - width_hz - high_hz = mid_hz + width_hz - low_hz[0] = 0.1 - return low_hz, high_hz - - def calc_freq_boundaries_linear(nyquist_freq): - n_kernels = int(nyquist_freq) - high_hz = np.linspace(1, nyquist_freq, n_kernels) - low_hz = high_hz - np.hstack([np.array(1), np.diff(high_hz)]) - low_hz[0] = 0.1 - return low_hz, high_hz - - if freq_scale == "linear": - fn = calc_freq_boundaries_linear - if freq_scale == "log": - fn = calc_freq_boundaries_log - low_hz, high_hz = fn(nyquist_freq) - - morlets = [] - freqs = [] - - for _, (ll, hh) in enumerate(zip(low_hz, high_hz)): - if ll > nyquist_freq: - break - - center_frequency = (ll + hh) / 2 - - t = np.arange(-kernel_size // 2, kernel_size // 2) / samp_rate - # Calculate standard deviation of the gaussian window for a given center frequency - sigma = 7 / (2 * np.pi * center_frequency) - sine_wave = np.exp(2j * np.pi * center_frequency * t) - gaussian_window = np.exp(-(t**2) / (2 * sigma**2)) - morlet_wavelet = sine_wave * gaussian_window - - freqs.append(center_frequency) - morlets.append(morlet_wavelet) - - return np.array(morlets), np.array(freqs) - - @property - def kernel_size( - self, - ): - return to_even(self.kernel.shape[-1]) - - @property - def radius( - self, - ): - return to_even(self.kernel_size // 2) - - -if __name__ == "__main__": - import matplotlib.pyplot as plt - - import scitex - - xx, tt, fs = scitex.dsp.demo_sig(sig_type="chirp") - - pha, amp, ff = scitex.dsp.wavelet(xx, fs) - - fig, ax = scitex.plt.subplots() - ax.imshow2d(amp[0, 0].T) - ax = scitex.plt.ax.set_ticks(ax, xticks=tt, yticks=ff) - ax = scitex.plt.ax.set_n_ticks(ax) - plt.show() - - -# EOF diff --git a/src/scitex/nn/__init__.py b/src/scitex/nn/__init__.py index 0a65c26ee..e1eb98081 100755 --- a/src/scitex/nn/__init__.py +++ b/src/scitex/nn/__init__.py @@ -1,75 +1,13 @@ -#!/usr/bin/env python3 -"""Scitex nn module.""" +"""SciTeX nn — thin compatibility shim for scitex-nn.""" -from ._AxiswiseDropout import AxiswiseDropout -from ._BNet import BHead as BHead_v1 -from ._BNet import BNet as BNet_v1 -from ._BNet import BNet_config as BNet_config_v1 -from ._BNet_Res import BHead as BHead_Res -from ._BNet_Res import BNet as BNet_Res -from ._BNet_Res import BNet_config as BNet_config_Res -from ._ChannelGainChanger import ChannelGainChanger -from ._DropoutChannels import DropoutChannels -from ._Filters import ( - BandPassFilter, - BandStopFilter, - BaseFilter1D, - DifferentiableBandPassFilter, - GaussianFilter, - HighPassFilter, - LowPassFilter, -) -from ._FreqGainChanger import FreqGainChanger +import sys as _sys -# Removed duplicate GaussianFilter import - already imported from _Filters -from ._Hilbert import Hilbert -from ._MNet_1000 import MNet1000, MNet_1000, MNet_config, ReshapeLayer, SwapLayer -from ._ModulationIndex import ModulationIndex -from ._PAC import PAC -from ._PSD import PSD -from ._ResNet1D import ResNet1D, ResNetBasicBlock -from ._SpatialAttention import SpatialAttention -from ._Spectrogram import Spectrogram, my_softmax, normalize, spectrograms, unbias -from ._SwapChannels import SwapChannels -from ._TransposeLayer import TransposeLayer -from ._Wavelet import Wavelet +try: + import scitex_nn as 
_real +except ImportError as _e: + raise ImportError( + "scitex.nn requires the 'scitex-nn' package. " + "Install with: pip install scitex[nn] (or: pip install scitex-nn)" + ) from _e -__all__ = [ - "AxiswiseDropout", - "BHead_v1", - "BHead_Res", - "BNet_v1", - "BNet_Res", - "BNet_config_v1", - "BNet_config_Res", - "BandPassFilter", - "BandStopFilter", - "BaseFilter1D", - "ChannelGainChanger", - "DifferentiableBandPassFilter", - "DropoutChannels", - "FreqGainChanger", - "GaussianFilter", - "HighPassFilter", - "Hilbert", - "LowPassFilter", - "MNet1000", - "MNet_1000", - "MNet_config", - "ModulationIndex", - "PAC", - "PSD", - "ResNet1D", - "ResNetBasicBlock", - "ReshapeLayer", - "SpatialAttention", - "Spectrogram", - "SwapChannels", - "SwapLayer", - "TransposeLayer", - "Wavelet", - "my_softmax", - "normalize", - "spectrograms", - "unbias", -] +_sys.modules[__name__] = _real diff --git a/src/scitex/nn/_skills/SKILL.md b/src/scitex/nn/_skills/SKILL.md deleted file mode 100644 index 9d4b8425a..000000000 --- a/src/scitex/nn/_skills/SKILL.md +++ /dev/null @@ -1,109 +0,0 @@ ---- -name: stx.nn -description: Neural network layers and modules for neuroscience signal processing (EEG, LFP, spectrogram, PAC). All layers are nn.Module subclasses that work inside standard PyTorch training pipelines. ---- - -# stx.nn - -The `stx.nn` module provides PyTorch `nn.Module` layers specialized for neuroscience signal processing. All layers accept 3-D tensors shaped `(batch_size, n_chs, seq_len)` unless noted, and are differentiable by default. - -## Sub-skills - -| File | Feature area | -|---|---| -| [filters.md](filters.md) | FIR filter layers: BandPassFilter, BandStopFilter, LowPassFilter, HighPassFilter, GaussianFilter, DifferentiableBandPassFilter, BaseFilter1D | -| [spectral.md](spectral.md) | Spectral analysis: Hilbert, PSD, Spectrogram, Wavelet | -| [pac.md](pac.md) | Phase-Amplitude Coupling: PAC, ModulationIndex | -| [architectures.md](architectures.md) | Complete models: ResNet1D, MNet1000, BNet_v1, BNet_Res | -| [augmentation.md](augmentation.md) | Training augmentation: DropoutChannels, SwapChannels, ChannelGainChanger, FreqGainChanger, AxiswiseDropout | -| [utility-layers.md](utility-layers.md) | Building blocks: SpatialAttention, TransposeLayer, GaussianFilter (radius-based), SwapLayer, ReshapeLayer | - -## Quick orientation - -```python -import scitex as stx -import torch -import numpy as np - -x = torch.randn(8, 19, 1024) # batch=8, channels=19, time=1024 - -# --- Filters --- -bands = np.array([[4.0, 8.0], [8.0, 13.0]]) -y = stx.nn.BandPassFilter(bands=bands, fs=256, seq_len=1024)(x) -# output: (8, 19, 2, 1024) — one output band per filter - -# --- Spectral --- -pha, amp, freqs = stx.nn.Wavelet(samp_rate=256)(x) -psd, freqs = stx.nn.PSD(sample_rate=256)(x) -out = stx.nn.Hilbert(seq_len=1024)(x) # (..., 2): [phase, amplitude] - -# --- PAC --- -pac_layer = stx.nn.PAC(seq_len=1024, fs=256) -pac = pac_layer(x) # (8, 19, n_pha_bands, n_amp_bands) - -# --- Architectures --- -model = stx.nn.ResNet1D(n_chs=19, n_blks=5) -model = stx.nn.MNet1000(stx.nn.MNet_config) - -# --- Augmentation (training only) --- -x = stx.nn.DropoutChannels(dropout=0.1)(x) -x = stx.nn.SwapChannels(dropout=0.5)(x) -x = stx.nn.ChannelGainChanger(n_chs=19)(x) - -# --- Utility --- -x = stx.nn.SpatialAttention(n_chs_in=19)(x) -x = stx.nn.TransposeLayer(axis1=1, axis2=2)(x) -``` - -## Exported names - -```python -# Filters -stx.nn.BaseFilter1D -stx.nn.BandPassFilter -stx.nn.BandStopFilter -stx.nn.LowPassFilter -stx.nn.HighPassFilter 
-stx.nn.GaussianFilter # from _Filters.py (sigma-based) -stx.nn.DifferentiableBandPassFilter - -# Spectral -stx.nn.Hilbert -stx.nn.PSD -stx.nn.Spectrogram -stx.nn.Wavelet -stx.nn.spectrograms # function (torch_fn decorated) -stx.nn.my_softmax # function -stx.nn.normalize # function -stx.nn.unbias # function - -# PAC -stx.nn.PAC -stx.nn.ModulationIndex - -# Architectures -stx.nn.ResNet1D -stx.nn.ResNetBasicBlock -stx.nn.MNet1000 -stx.nn.MNet_1000 # deprecated alias for MNet1000 -stx.nn.MNet_config # default config dict -stx.nn.BNet_v1 # alias BNet from _BNet.py -stx.nn.BNet_Res # alias BNet from _BNet_Res.py -stx.nn.BNet_config_v1 # default config dict (_BNet.py) -stx.nn.BNet_config_Res # default config dict (_BNet_Res.py) -stx.nn.BHead_v1 # BHead from _BNet.py -stx.nn.BHead_Res # BHead from _BNet_Res.py -stx.nn.SwapLayer -stx.nn.ReshapeLayer - -# Augmentation -stx.nn.AxiswiseDropout -stx.nn.DropoutChannels -stx.nn.SwapChannels -stx.nn.ChannelGainChanger -stx.nn.FreqGainChanger - -# Utility -stx.nn.SpatialAttention -stx.nn.TransposeLayer -``` diff --git a/src/scitex/nn/_skills/architectures.md b/src/scitex/nn/_skills/architectures.md deleted file mode 100644 index 8884ca6f1..000000000 --- a/src/scitex/nn/_skills/architectures.md +++ /dev/null @@ -1,175 +0,0 @@ ---- -description: Complete neural network architectures for biosignal classification — ResNet1D, MNet1000, BNet (v1 and Residual). Designed for multi-channel EEG/MEG/LFP data shaped (batch, n_chs, seq_len). ---- - -# stx.nn — Model Architectures - ---- - -## ResNet1D - -1-D residual convolutional network. Suitable as a general-purpose signal classifier backbone. - -```python -import torch -import scitex as stx - -model = stx.nn.ResNet1D( - n_chs=19, # input channels - n_out=10, # output classes (reserved for FC head — currently not built in) - n_blks=5, # number of ResNetBasicBlock residual blocks -) - -x = torch.randn(16, 19, 1024) # (batch, n_chs, seq_len) -y = model(x) # (batch, n_chs * 4, seq_len) — feature map, no FC head -``` - -### Notes -- Each block increases channels: block 0 maps `n_chs → n_chs * 4`; subsequent blocks keep `n_chs * 4`. -- The FC classification head is commented out in the current source; `forward()` returns the feature map. -- Use `ResNetBasicBlock` directly if you need a single residual stage inside another model. - ---- - -## ResNetBasicBlock - -A single residual block with three convolutions (k=7, k=5, k=3) and a channel-expansion shortcut. - -```python -block = stx.nn.ResNetBasicBlock( - in_chs=19, # input channels - out_chs=76, # output channels (commonly 4× in_chs) -) - -x = torch.randn(16, 19, 1024) -y = block(x) # (16, 76, 1024) — spatial dim preserved via padding -``` - -Each block: Conv(k=7) → BN → ReLU → Conv(k=5) → BN → ReLU → Conv(k=3) → BN → (+shortcut) → BN → ReLU. -Shortcut uses Conv(k=1) when `in_chs != out_chs`. - ---- - -## MNet1000 / MNet_1000 - -A 2-D convolutional network originally designed for 270-channel signals at 1000 Hz. -`MNet_1000` is a deprecated alias for backward compatibility. 
- -```python -MNet_config = { - "classes": ["wake", "nrem", "rem"], # class labels (len = n_output_classes) - "n_chs": 270, # number of input channels - "n_fc1": 1024, # FC hidden size 1 - "d_ratio1": 0.85, # dropout probability after FC1 - "n_fc2": 256, # FC hidden size 2 - "d_ratio2": 0.85, # dropout probability after FC2 -} - -model = stx.nn.MNet1000(config=MNet_config) - -x = torch.randn(16, 270, 1000) -y = model(x) # (16, n_classes) logits -``` - -### Two-stage forward -```python -features = model.forward_bb(x) # backbone only — returns (batch, n_fc1) features -logits = model.fc(features) # apply classification head -``` - -### Utility layers (also exported) -```python -# SwapLayer — transposes axes 1 and 2 -swap = stx.nn.SwapLayer() -y = swap(x) # x.transpose(1, 2) - -# ReshapeLayer — flattens all dims except batch -reshape = stx.nn.ReshapeLayer() -y = reshape(x) # x.reshape(len(x), -1) -``` - ---- - -## BNet_v1 (BNet from _BNet.py) - -Multi-modal biosignal network. Accepts data from multiple recording modalities (e.g., MEG + EEG) with a shared backbone and per-modality heads. - -```python -BNet_config = { - "n_bands": 6, - "SAMP_RATE": 250, - "n_chs": [160, 19], # channels per modality - "n_classes": [2, 4], # output classes per modality - "n_fc1": 1024, - "d_ratio1": 0.85, - "n_fc2": 256, - "d_ratio2": 0.85, -} - -model = stx.nn.BNet_v1( - BNet_config=BNet_config, - MNet_config=stx.nn.MNet_config, # MNet_config from _MNet_1000.py -) - -x_meg = torch.randn(16, 160, 1000) -y_meg = model(x_meg, i_head=0) # use modality 0 (MEG) - -x_eeg = torch.randn(16, 19, 1000) -y_eeg = model(x_eeg, i_head=1) # use modality 1 (EEG) -``` - -### Forward pipeline -``` -z-score → DropoutChannels → FreqGainChanger -→ ChannelGainChanger[i_head] → BHead[i_head] (SpatialAttention + Conv1x1) -→ MNet1000 backbone (forward_bb) → FC head[i_head] → logits -``` - -**Warning:** The current `_BNet.py` source contains a debug `ipdb.set_trace()` call in `forward()`. - ---- - -## BNet_Res (BNet from _BNet_Res.py) - -Variant that replaces the MNet1000 backbone with `ResNetBasicBlock` blocks and average-pooling stages. - -```python -BNet_config = { - "n_bands": 6, - "n_virtual_chs": 16, - "SAMP_RATE": 250, - "n_chs_of_modalities": [160, 19], - "n_classes_of_modalities": [2, 4], - "n_fc1": 1024, - "d_ratio1": 0.85, - "n_fc2": 256, - "d_ratio2": 0.85, -} - -model = stx.nn.BNet_Res(BNet_config=BNet_config, MNet_config=stx.nn.MNet_config) -y = model(x_meg, i_head=0) -``` - -### Backbone structure (after per-modality head) -``` -BHead → blk1 → AvgPool → blk2 → AvgPool → blk3 → AvgPool → blk4 → AvgPool - → blk5 → AvgPool → blk6 → AvgPool → blk7 → AvgPool -``` -Each `blkN` is a `ResNetBasicBlock`. Channel counts halve every two blocks. - -**Warning:** The current `_BNet_Res.py` source also contains a debug `ipdb.set_trace()` call. - ---- - -## Default config objects - -```python -# MNet default config (n_chs=270) -stx.nn.MNet_config # dict from _MNet_1000.py - -# BNet v1 default config -stx.nn.BNet_config_v1 # dict from _BNet.py - -# BNet Res default config -stx.nn.BNet_config_Res # dict from _BNet_Res.py -``` diff --git a/src/scitex/nn/_skills/augmentation.md b/src/scitex/nn/_skills/augmentation.md deleted file mode 100644 index 515548c41..000000000 --- a/src/scitex/nn/_skills/augmentation.md +++ /dev/null @@ -1,130 +0,0 @@ ---- -description: Training-only augmentation layers for multi-channel biosignals — channel dropout, channel swapping, channel gain jitter, frequency band gain jitter, and axis-wise dropout. 
---
-
-# stx.nn — Data Augmentation Layers
-
-All augmentation layers are **no-ops at eval time** (i.e., they return `x` unchanged when `model.eval()` is active). They are `nn.Module` subclasses and integrate seamlessly into `nn.Sequential` pipelines.
-
-Input convention: `(batch_size, n_chs, seq_len)` unless noted.
-
----
-
-## DropoutChannels
-
-Replaces a random subset of channels with Gaussian noise.
-
-```python
-import torch
-import scitex as stx
-
-layer = stx.nn.DropoutChannels(dropout=0.5)
-# dropout: float — probability that any given channel is replaced with noise
-
-layer.train()
-x = torch.randn(16, 19, 1024)
-y = layer(x)  # some channels replaced with torch.randn(...) noise
-
-layer.eval()
-y = layer(x)  # identical to x, no modification
-```
-
-Channels selected for replacement are identified by applying `nn.Dropout` to a ones-vector, then setting those channel slots to fresh standard-normal samples on the same device as `x`.
-
----
-
-## SwapChannels
-
-Randomly permutes a subset of channels during training.
-
-```python
-layer = stx.nn.SwapChannels(dropout=0.5)
-# dropout: float — probability that any channel participates in swapping
-
-layer.train()
-y = layer(x)  # some channels shuffled among themselves
-layer.eval()
-y = layer(x)  # x unchanged
-```
-
-Channels not selected by the dropout mask keep their original positions; selected channels are randomly permuted among themselves using `random.sample`.
-
----
-
-## ChannelGainChanger
-
-Applies per-channel random gain during training.
-
-```python
-layer = stx.nn.ChannelGainChanger(n_chs=19)
-# n_chs: int — must match x.shape[1]
-
-layer.train()
-y = layer(x)
-# Each channel is multiplied by a random gain drawn from the [0.5, 1.5] range,
-# then softmax-normalised across channels (the gain vector sums to one).
-
-layer.eval()
-y = layer(x)  # x unchanged
-```
-
-Gain vector: `rand(n_chs) + 0.5`, then `softmax(gains, dim=1)`.
-
----
-
-## FreqGainChanger
-
-Splits the signal into `n_bands` frequency sub-bands using Julius, applies a random gain per band, then sums them back.
-
-```python
-layer = stx.nn.FreqGainChanger(
-    n_bands=6,
-    samp_rate=250,
-    dropout_ratio=0.5,  # parameter exists but is not currently used in forward()
-)
-
-layer.train()
-y = layer(x)
-# Internally: julius.bands.split_bands → per-band random gains (softmax-normalised) → sum
-
-layer.eval()
-y = layer(x)  # x unchanged
-```
-
-**Dependency:** Requires the `julius` package (`pip install julius`).
-
----
-
-## AxiswiseDropout
-
-Drops entire slices along a specified axis at training time (structured dropout).
-
-```python
-layer = stx.nn.AxiswiseDropout(
-    dropout_prob=0.5,  # probability of dropping a slice
-    dim=1,             # axis to apply structured dropout on
-)
-
-layer.train()
-x = torch.randn(8, 32, 1024)
-y = layer(x)
-# A binary mask of shape (8, 32, 1) is generated; zero-masked channels are zeroed
-# across the entire time dimension (the whole channel is zeroed, not individual samples)
-
-layer.eval()
-y = layer(x)  # x unchanged
-```
-
-Use `dim=0` to drop entire batch examples, `dim=1` for channels, `dim=-1` for time steps.
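-
----
-
-## Composing augmentations
-
-The layers above chain naturally in `nn.Sequential`. The block below is a minimal sketch of a training-time augmentation front-end; constructor arguments follow the sections above, and `FreqGainChanger` is left out only because it additionally requires `julius`:
-
-```python
-import torch
-import torch.nn as nn
-import scitex as stx
-
-augment = nn.Sequential(
-    stx.nn.DropoutChannels(dropout=0.1),              # replace ~10% of channels with noise
-    stx.nn.SwapChannels(dropout=0.5),                 # shuffle ~50% of channels
-    stx.nn.ChannelGainChanger(n_chs=19),              # jitter per-channel gain
-    stx.nn.AxiswiseDropout(dropout_prob=0.2, dim=1),  # zero out whole channels
-)
-
-x = torch.randn(16, 19, 1024)
-
-augment.train()
-y = augment(x)  # randomised; shape preserved: (16, 19, 1024)
-
-augment.eval()
-y = augment(x)  # every layer passes x through unchanged at eval time
-```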
- ---- - -## Summary table - -| Layer | What it randomises | Preserves shape | -|---|---|---| -| `DropoutChannels` | Channel content (replaced with noise) | Yes | -| `SwapChannels` | Channel ordering | Yes | -| `ChannelGainChanger` | Per-channel amplitude scaling | Yes | -| `FreqGainChanger` | Per-frequency-band amplitude | Yes | -| `AxiswiseDropout` | Structured zeros along one axis | Yes | diff --git a/src/scitex/nn/_skills/filters.md b/src/scitex/nn/_skills/filters.md deleted file mode 100644 index 94a73e512..000000000 --- a/src/scitex/nn/_skills/filters.md +++ /dev/null @@ -1,158 +0,0 @@ ---- -description: Fixed and differentiable FIR filter layers for 1D biosignal processing. Input shape is always (batch_size, n_chs, seq_len); output adds a filter dimension. ---- - -# stx.nn — Filters - -All filter classes inherit from `BaseFilter1D(nn.Module)` and apply FIR convolution with edge-reflection padding to avoid boundary artifacts. - -## Input / Output contract - -``` -input: (batch_size, n_chs, seq_len) -output: (batch_size, n_chs, n_filters, seq_len) -``` - -`edge_len` can be passed to `forward()` to trim transient edges from the output. -Pass `edge_len="auto"` to trim `seq_len // 8` samples from each end. - ---- - -## BandPassFilter - -Keep only the energy within specified frequency bands. - -```python -import numpy as np -import torch -import scitex as stx - -bands = np.array([[4.0, 8.0], [8.0, 13.0], [13.0, 30.0]]) # (n_bands, 2) Hz -fs = 256 # sampling rate Hz -seq_len = 1024 - -layer = stx.nn.BandPassFilter(bands=bands, fs=fs, seq_len=seq_len) -# bands: np.ndarray or torch.Tensor, shape (n_bands, 2) — [low_hz, high_hz] per row -# fs: float — sampling rate in Hz -# seq_len: int — expected input length (determines kernel length) -# fp16: bool — half precision (default False) - -x = torch.randn(8, 19, seq_len) -y = layer(x) # (8, 19, 3, 1024) -y, t = layer(x, t=time_vector) # also trims time vector if provided -``` - -### Constraints -- Each band must satisfy: `0 < low_hz < high_hz < fs/2` -- Bands are clipped to `[0.1, nyquist - 1]` automatically. - ---- - -## BandStopFilter - -Attenuate (notch) the energy within specified frequency bands. - -```python -bands = np.array([[49.0, 51.0], [99.0, 101.0]]) # 50 Hz + 100 Hz notch -layer = stx.nn.BandStopFilter(bands=bands, fs=fs, seq_len=seq_len) -# Same signature as BandPassFilter except no fp16 argument -``` - ---- - -## LowPassFilter - -Keep energy below cutoff frequencies. - -```python -cutoffs_hz = np.array([30.0, 50.0]) # shape (n_cutoffs,) — one filter per cutoff -layer = stx.nn.LowPassFilter(cutoffs_hz=cutoffs_hz, fs=fs, seq_len=seq_len) - -y = layer(x) # (batch, n_chs, 2, seq_len) -``` - ---- - -## HighPassFilter - -Keep energy above cutoff frequencies. - -```python -cutoffs_hz = np.array([1.0, 4.0]) -layer = stx.nn.HighPassFilter(cutoffs_hz=cutoffs_hz, fs=fs, seq_len=seq_len) -``` - ---- - -## GaussianFilter (from _Filters.py) - -Gaussian smoothing along the time axis. The kernel covers ± 3 standard deviations. - -```python -layer = stx.nn.GaussianFilter(sigma=5) -# sigma: int — standard deviation in samples. kernel_size = sigma * 6. -# Note: there are TWO GaussianFilter classes in the module. 
-# stx.nn.GaussianFilter → imported from _Filters.py (subclass of BaseFilter1D) -# _GaussianFilter.GaussianFilter → separate class with radius-based constructor -``` - -The `_Filters.py` version: -- output shape: `(batch, n_chs, 1, seq_len)` — a single filter dimension -- kernel is normalized to sum = 1 - ---- - -## DifferentiableBandPassFilter - -A learnable filter bank designed for Phase-Amplitude Coupling (PAC) pipelines. -Band center frequencies are `nn.Parameter`s that can be gradient-updated. - -```python -layer = stx.nn.DifferentiableBandPassFilter( - sig_len=1024, - fs=256, - pha_low_hz=2, # lower bound for phase-band centers - pha_high_hz=20, - pha_n_bands=30, # number of phase filters - amp_low_hz=80, # lower bound for amplitude-band centers - amp_high_hz=160, - amp_n_bands=50, # number of amplitude filters - cycle=3, # number of cycles per wavelet kernel - fp16=False, -) - -# Learnable parameters (center frequencies): -print(layer.pha_mids) # nn.Parameter, shape (pha_n_bands,) -print(layer.amp_mids) # nn.Parameter, shape (amp_n_bands,) - -y = layer(x) # (batch, n_chs, pha_n_bands + amp_n_bands, seq_len) -y.sum().backward() # gradients flow through to pha_mids / amp_mids -``` - -### Notes -- During `forward()`, `pha_mids` and `amp_mids` are clamped to their declared ranges. -- Used internally by `PAC(trainable=True)`. - ---- - -## BaseFilter1D - -Abstract base class. Extend it to add a custom filter type: - -```python -class MyFilter(stx.nn.BaseFilter1D): - def __init__(self, ...): - super().__init__(fp16=False) - kernels = ... # torch.Tensor shape (n_filters, kernel_len) - self.register_buffer("kernels", kernels) - - def init_kernels(self): - pass # required by abstractmethod; logic can live in __init__ - -# forward() is inherited; applies flip-extend padding + batch_conv -``` - -Key static helpers available on every filter: -- `BaseFilter1D.flip_extend(x, extension_length)` — reflect-pad both ends -- `BaseFilter1D.batch_conv(x, kernels, padding)` — grouped 1-D convolution over batch × channels -- `BaseFilter1D.remove_edges(x, edge_len)` — trim edge artifacts diff --git a/src/scitex/nn/_skills/pac.md b/src/scitex/nn/_skills/pac.md deleted file mode 100644 index 43fd58f1d..000000000 --- a/src/scitex/nn/_skills/pac.md +++ /dev/null @@ -1,132 +0,0 @@ ---- -description: GPU-accelerated differentiable Phase-Amplitude Coupling (PAC) and Modulation Index layers for EEG/LFP analysis. Supports static and learnable filter banks, surrogate-based z-scoring. ---- - -# stx.nn — Phase-Amplitude Coupling (PAC) - -PAC measures how the amplitude of high-frequency oscillations is modulated by the phase of low-frequency oscillations. The `PAC` module is the high-level interface; `ModulationIndex` is the underlying metric. 
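-
-As a quick sanity check, a synthetic signal whose 80 Hz amplitude is modulated by a 6 Hz phase should yield a PAC map peaking near that frequency pair. This is only a sketch, assuming the `PAC` constructor documented in the next section with its default band settings:
-
-```python
-import torch
-import scitex as stx
-
-fs, t_sec = 512, 4
-t = torch.arange(int(fs * t_sec)) / fs
-slow = torch.sin(2 * torch.pi * 6.0 * t)             # 6 Hz phase driver
-fast = torch.sin(2 * torch.pi * 80.0 * t)            # 80 Hz carrier
-x = (1.0 + slow) * fast + 0.1 * torch.randn_like(t)  # 80 Hz amplitude follows the 6 Hz phase
-x = x[None, None, :]                                 # (batch=1, n_chs=1, seq_len)
-
-pac = stx.nn.PAC(seq_len=x.shape[-1], fs=fs)(x)      # (1, 1, pha_n_bands, amp_n_bands)
-# The argmax should land near (pha ≈ 6 Hz, amp ≈ 80 Hz).
-```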
- ---- - -## PAC - -```python -import torch -import scitex as stx - -layer = stx.nn.PAC( - seq_len=4096, # samples per segment - fs=512, # sampling rate Hz - pha_start_hz=2, # phase band lower bound Hz - pha_end_hz=20, # phase band upper bound Hz - pha_n_bands=50, # number of phase frequency bands - amp_start_hz=60, # amplitude band lower bound Hz - amp_end_hz=160, # amplitude band upper bound Hz - amp_n_bands=30, # number of amplitude frequency bands - n_perm=None, # int or None — if int, z-score against n_perm surrogates - trainable=False, # if True, use DifferentiableBandPassFilter (learnable bands) - in_place=True, - fp16=False, # half precision - amp_prob=False, # if True, return amplitude probability distribution instead of MI -) -``` - -### Input shape -``` -x: (batch_size, n_chs, seq_len) # 3D -x: (batch_size, n_chs, n_segments, seq_len) # 4D preferred -``` -3D input is automatically unsqueezed to 4D (n_segments=1). - -### Output shape -```python -pac = layer(x) -# amp_prob=False, n_perm=None: -# pac.shape: (batch_size, n_chs, pha_n_bands, amp_n_bands) -# dtype: float16 - -# amp_prob=True: -# returns amplitude probability per phase bin -# shape: (batch_size, n_chs, pha_n_bands, amp_n_bands, n_segments, n_bins=18) - -# n_perm=N (int): -# returns PAC z-scored against N cut-and-shift surrogates -# same shape as n_perm=None case -``` - -### Accessing frequency axes -```python -layer.PHA_MIDS_HZ # center frequencies for phase bands, shape (pha_n_bands,) -layer.AMP_MIDS_HZ # center frequencies for amplitude bands, shape (amp_n_bands,) -``` - -### Trainable mode (learnable band centers) -```python -layer = stx.nn.PAC(seq_len=4096, fs=512, trainable=True) -# Uses DifferentiableBandPassFilter internally -# layer.PHA_MIDS_HZ and layer.AMP_MIDS_HZ are nn.Parameter objects -# Gradients flow back through the filter centers - -pac = layer(x) -pac.sum().backward() # works -``` - -### Pipeline internals -``` -x -→ BandPassFilter (or DifferentiableBandPassFilter) - output: (batch*n_chs, n_segs, n_pha+n_amp, seq_len) -→ Hilbert - output: (batch, n_chs, n_segs, n_pha+n_amp, seq_len, 2) — [phase, amp] -→ edge trimming (seq_len // 8 from each end) -→ ModulationIndex - output: (batch, n_chs, n_pha, n_amp) -``` - ---- - -## ModulationIndex - -The Tort et al. (2010) Modulation Index metric. Used directly when you already have phase and amplitude tensors. - -```python -layer = stx.nn.ModulationIndex( - n_bins=18, # number of phase bins over [-pi, pi] - fp16=False, - amp_prob=False, # if True, return amplitude probability per bin instead of MI -) - -# Required shapes: -# pha: (batch_size, n_channels, n_freqs_pha, n_segments, seq_len) -# amp: (batch_size, n_channels, n_freqs_amp, n_segments, seq_len) - -mi = layer(pha, amp) -# mi.shape: (batch_size, n_channels, n_freqs_pha, n_freqs_amp) -# Values are averaged across n_segments dimension -``` - -### Phase bin centers -```python -layer.pha_bin_centers # numpy array, shape (n_bins,), values in [-pi, pi] -``` - -### Algorithm -1. Assign each time sample to one of `n_bins` phase bins. -2. Compute mean amplitude per bin: `amp_mean[bin] = mean(amp[pha in bin])`. -3. Normalize to probability distribution: `amp_prob = amp_mean / sum(amp_mean)`. -4. MI = `(log(n_bins) + sum(amp_prob * log(amp_prob))) / log(n_bins)` - — i.e., 1 - normalised entropy, so MI = 0 for uniform, MI = 1 for perfectly concentrated. 
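-
-For orientation, here are the same steps in plain NumPy. This is a hypothetical standalone helper for a single phase/amplitude pair, not the module's batched implementation, and it assumes every phase bin receives at least one sample:
-
-```python
-import numpy as np
-
-def tort_mi(pha, amp, n_bins=18):
-    """pha, amp: 1-D arrays over time; returns MI in [0, 1]."""
-    edges = np.linspace(-np.pi, np.pi, n_bins + 1)
-    idx = np.clip(np.digitize(pha, edges) - 1, 0, n_bins - 1)           # step 1
-    amp_mean = np.array([amp[idx == b].mean() for b in range(n_bins)])  # step 2
-    amp_prob = amp_mean / amp_mean.sum()                                # step 3
-    entropy = -np.sum(amp_prob * np.log(amp_prob))
-    return (np.log(n_bins) - entropy) / np.log(n_bins)                  # step 4
-```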
- ---- - -## Band helper methods (PAC static methods) - -```python -# Phase bands: center ± 25% of center frequency -bands_pha = stx.nn.PAC.calc_bands_pha(start_hz=2, end_hz=20, n_bands=50) -# shape: (50, 2) - -# Amplitude bands: center ± 12.5% of center frequency -bands_amp = stx.nn.PAC.calc_bands_amp(start_hz=30, end_hz=160, n_bands=100) -# shape: (100, 2) -``` diff --git a/src/scitex/nn/_skills/spectral.md b/src/scitex/nn/_skills/spectral.md deleted file mode 100644 index 81b3cce4b..000000000 --- a/src/scitex/nn/_skills/spectral.md +++ /dev/null @@ -1,142 +0,0 @@ ---- -description: Differentiable spectral analysis layers — Hilbert transform, Spectrogram (STFT), PSD, and Morlet Wavelet transform. All are nn.Module subclasses usable inside training loops. ---- - -# stx.nn — Spectral Analysis Layers - ---- - -## Hilbert - -Computes the analytic signal via the Hilbert transform and returns instantaneous **phase** and **amplitude** as the last dimension. - -```python -import torch -import scitex as stx - -seq_len = 1024 -layer = stx.nn.Hilbert( - seq_len=seq_len, - dim=-1, # dimension along which to apply FFT (default -1) - fp16=False, # use half precision - in_place=False # if True, skips cloning the input -) - -x = torch.randn(8, 19, seq_len) -out = layer(x) -# out.shape: (8, 19, seq_len, 2) -# out[..., 0] — instantaneous phase (radians, range -pi to pi) -# out[..., 1] — instantaneous amplitude (envelope) -``` - -### Implementation details -- Uses `torch.fft.fft` / `torch.fft.ifft` for differentiability. -- Step function is approximated with `sigmoid(steepness=50 * freq)` to preserve gradients. -- Frequency buffer `f` is registered at init time (no re-allocation during forward). -- Output is always cast to float32 even when `fp16=True`. - ---- - -## PSD - -Differentiable Power Spectral Density via FFT. - -```python -layer = stx.nn.PSD( - sample_rate=256, # Hz - prob=False, # if True, normalises PSD to sum=1 (treat as probability) - dim=-1, # dimension of the time axis -) - -x = torch.randn(8, 19, 1024) -psd, freqs = layer(x) -# psd.shape: same as x with the time dimension replaced by n_freq_bins -# freqs.shape: (n_freq_bins,) — frequencies in Hz - -# Complex input uses torch.fft.fft; real input uses torch.fft.rfft -``` - -### Notes -- Normalisation: `psd = |FFT(x)|^2 / seq_len / sample_rate` -- `prob=True` divides by `psd.sum(dim)` so bins sum to 1. - ---- - -## Spectrogram - -STFT-based spectrogram over multi-channel signals. - -```python -layer = stx.nn.Spectrogram( - sampling_rate=256, # Hz - n_fft=256, # FFT size - hop_length=None, # default n_fft // 4 - win_length=None, # default n_fft - window="hann", # only "hann" is supported -) - -x = torch.randn(8, 19, 4096) # (batch, n_chs, seq_len) -specs, freqs, times_sec = layer(x) -# specs.shape: (batch, n_chs, n_fft//2 + 1, n_frames) — magnitude spectrogram -# freqs.shape: (n_fft//2 + 1,) — Hz -# times_sec.shape: (n_frames,) — seconds -``` - -### Convenience function -```python -from scitex.nn._Spectrogram import spectrograms - -specs, freqs, times_sec = spectrograms(x, fs=256, cuda=True) -# Wraps Spectrogram(fs) in a @torch_fn decorator — accepts numpy or torch input -``` - ---- - -## Wavelet - -Morlet continuous wavelet transform up to the Nyquist frequency. -Returns phase, log-amplitude (or amplitude), and frequency axes. 
- -```python -layer = stx.nn.Wavelet( - samp_rate=256, - kernel_size=None, # default = samp_rate samples - freq_scale="linear", # "linear" or "log" — how frequency bins are spaced - out_scale="log", # "log" applies log(amp + 1e-5); anything else returns raw amp -) - -x = torch.randn(8, 19, 4096) # (batch, n_chs, seq_len) -pha, amp, freqs = layer(x) -# pha.shape: (batch, n_chs, n_freqs, seq_len) — instantaneous phase (radians) -# amp.shape: (batch, n_chs, n_freqs, seq_len) — log-amplitude or raw amplitude -# freqs.shape: (batch, n_chs, n_freqs) — center frequency per filter (Hz) -``` - -### Frequency spacing -| `freq_scale` | Number of filters | Spacing | -|---|---|---| -| `"linear"` | `int(nyquist)` | 1 Hz steps up to Nyquist | -| `"log"` | `floor(log2(nyquist))` | Powers of 2 up to Nyquist | - -### Implementation -- Kernels are complex Morlet wavelets (`sigma = 7 / (2π * center_freq)`). -- Real and imaginary parts convolved separately, then combined via `torch.view_as_complex`. -- Edge-reflection padding by `radius = kernel_size // 2` samples. - ---- - -## Utility functions (from _Spectrogram.py) - -```python -from scitex.nn._Spectrogram import my_softmax, unbias, normalize - -# my_softmax — softmax along a dimension -y = my_softmax(x, dim=-1) # @torch_fn decorated - -# unbias — subtract min or mean along a dimension -y = unbias(x, func="min", dim=-1) # func: "min" or "mean" -y = unbias(x, func="mean", dim=-1, cuda=True) - -# normalize — scale by max absolute value -y = normalize(x, axis=-1, amp=1.0) -``` diff --git a/src/scitex/nn/_skills/utility-layers.md b/src/scitex/nn/_skills/utility-layers.md deleted file mode 100644 index 4d6db2473..000000000 --- a/src/scitex/nn/_skills/utility-layers.md +++ /dev/null @@ -1,96 +0,0 @@ ---- -description: Lightweight nn.Module helpers — SpatialAttention, TransposeLayer, and the standalone GaussianFilter (_GaussianFilter.py). These are building blocks for larger architectures. ---- - -# stx.nn — Utility Layers - ---- - -## SpatialAttention - -Computes a learned spatial (channel-wise) attention weight and multiplies the input by it. Used inside `BHead` in the BNet architectures. - -```python -import torch -import scitex as stx - -layer = stx.nn.SpatialAttention(n_chs_in=19) -# n_chs_in: int — number of input channels - -x = torch.randn(8, 19, 1024) # (batch, n_chs, seq_len) -y = layer(x) # (8, 19, 1024) — same shape as input - -# Mechanism: -# 1. AdaptiveAvgPool1d(1) → (batch, n_chs, 1) — global average per channel -# 2. Conv1d(n_chs_in, 1, kernel_size=1) → (batch, 1, 1) — scalar weight -# 3. Return scalar_weight * x_orig (broadcast over time dim) -``` - -This is a simplified attention that produces a single scalar gate rather than per-channel weights. - ---- - -## TransposeLayer - -Wraps `torch.Tensor.transpose` as an `nn.Module` so it can be used inside `nn.Sequential`. - -```python -layer = stx.nn.TransposeLayer(axis1=1, axis2=2) -# axis1, axis2: int — the two dimensions to swap - -x = torch.randn(8, 19, 1024) -y = layer(x) # (8, 1024, 19) -``` - ---- - -## GaussianFilter (from _GaussianFilter.py) - -A separate, radius-based Gaussian smoothing layer (different from `GaussianFilter` in `_Filters.py`). 
- -```python -layer = stx.nn.GaussianFilter # NOTE: this name resolves to _Filters.py's version - # (imported last in __init__.py) - -# To use _GaussianFilter.py's version directly: -from scitex.nn._GaussianFilter import GaussianFilter as GaussianFilterRadius - -layer = GaussianFilterRadius( - radius=5, # half-width in samples; kernel_size = 2 * radius + 1 - sigma=None, # if None, sigma = radius / 2 -) - -x = torch.randn(8, 19, 1024) -y = layer(x) # same shape as input (padding=radius preserves length) - -# Works on 1D, 2D, or 3D inputs: -# 1D: unsqueezed to (1, 1, seq_len) -# 2D: unsqueezed to (batch, 1, seq_len) -# 3D: (batch, n_chs, seq_len) — applied with grouped convolution -``` - -### Comparison of the two GaussianFilter classes - -| | `_Filters.GaussianFilter` | `_GaussianFilter.GaussianFilter` | -|---|---|---| -| Constructor | `GaussianFilter(sigma)` | `GaussianFilter(radius, sigma=None)` | -| Kernel size | `sigma * 6` | `2 * radius + 1` | -| Output shape | `(batch, n_chs, 1, seq_len)` — adds filter dim | `(batch, n_chs, seq_len)` — preserves shape | -| Exported as | `stx.nn.GaussianFilter` | Must import from `_GaussianFilter` directly | -| Normalisation | Sum = 1 | Normalised Gaussian PDF then divided by sum | - ---- - -## SwapLayer and ReshapeLayer - -These are internal helper layers used inside `MNet1000`. They are also exported from `stx.nn`. - -```python -# SwapLayer — identical to TransposeLayer(1, 2) -swap = stx.nn.SwapLayer() -y = swap(x) # x.transpose(1, 2) - -# ReshapeLayer — flatten all dims except batch -reshape = stx.nn.ReshapeLayer() -y = reshape(x) # x.reshape(len(x), -1) -```
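-
-A minimal sketch of how `ReshapeLayer` typically bridges convolutional features to a linear head (shapes as documented above; the layer sizes here are illustrative):
-
-```python
-import torch
-import torch.nn as nn
-import scitex as stx
-
-classifier = nn.Sequential(
-    stx.nn.ReshapeLayer(),    # (batch, n_chs, seq_len) -> (batch, n_chs * seq_len)
-    nn.Linear(19 * 1024, 3),  # illustrative 3-class head
-)
-
-logits = classifier(torch.randn(8, 19, 1024))  # (8, 3)
-```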