From 280d4b146927201f1035a88314957f1adb60882e Mon Sep 17 00:00:00 2001
From: Chetan Gohil
Date: Fri, 17 Apr 2026 20:16:49 +0100
Subject: [PATCH 1/5] add: generic h5 dataset; rename CamcanGlasserDataModule
 -> EphysDataModule

H5Session/H5Dataset/build_h5_dataset load windowed continuous signals
from per-session h5 files and expose the .dataset (ConcatDataset) +
.subject interface that the DataModule already consumed, so they plug
in as a drop-in replacement for pnpl.datasets.CamcanGlasser. The module
(and its collate fn) were always dataset-agnostic; renamed to reflect
that.
---
 ephys_tokenizer/data/dataloader.py | 212 ++++++++++++++++++++++++++++-
 examples/train_etkn.py            |   4 +-
 2 files changed, 207 insertions(+), 9 deletions(-)

diff --git a/ephys_tokenizer/data/dataloader.py b/ephys_tokenizer/data/dataloader.py
index 91ce6e1..1dc6e08 100644
--- a/ephys_tokenizer/data/dataloader.py
+++ b/ephys_tokenizer/data/dataloader.py
@@ -1,12 +1,22 @@
-"""PyTorch DataLoader and Lightning DataModule for tokenizer training."""
+"""PyTorch DataLoader and Lightning DataModule for tokenizer training.
+
+Also provides a generic h5-backed ``Dataset`` (``H5Session`` / ``H5Dataset`` /
+``build_h5_dataset``) that plugs into :class:`EphysDataModule` as a drop-in
+replacement for ``pnpl.datasets.CamcanGlasser``.
+"""
 
 # Import packages
 from __future__ import annotations
 
+import os
+import random
+
+import h5py
 import numpy as np
+import pandas as pd
 import pytorch_lightning as pl
-import random
 import torch
+from pathlib import Path
 from torch.utils.data import (
     ConcatDataset,
     DataLoader,
@@ -45,9 +55,9 @@ def _default_worker_init_fn(worker_id: int) -> None:
     torch.manual_seed(seed)
 
 
-def _collate_camcan(batch: Sequence[dict]) -> Dict[str, Any]:
+def _collate_default(batch: Sequence[dict]) -> Dict[str, Any]:
     """
-    Collate function for CamcanGlasser items.
+    Default collate function for session-window items.
 
     Each item is expected to be:
         {"data": np.ndarray (L, C), "times": np.ndarray (L,), "info": dict}
@@ -84,7 +94,7 @@ def _make_dataloader(
     persistent_workers: Optional[bool] = True,
     drop_last: Optional[bool] = False,
     sampler: Optional[Sampler] = None,
-    collate_fn=_collate_camcan,
+    collate_fn=_collate_default,
 ) -> DataLoader:
     """
     Creates a DataLoader for a given dataset.
@@ -130,9 +140,13 @@ def _make_dataloader(
     )
 
 
-class CamcanGlasserDataModule(pl.LightningDataModule):
+class EphysDataModule(pl.LightningDataModule):
     """
-    Lightning DataModule for the CamCAN dataset.
+    Lightning DataModule for continuous session-based datasets.
+
+    Expects the input dataset to wrap a ``ConcatDataset`` of per-session
+    sub-datasets (accessible via ``dataset.dataset``). Each sub-dataset must
+    expose a ``.subject`` attribute if subject-level splitting is used.
 
     Parameters
     ----------
@@ -441,3 +455,187 @@ def full_dataloader(self) -> DataLoader:
             persistent_workers=self.persistent_workers,
             drop_last=False,
         )
+
+
+# ---------------------------------------------------------------------------
+# Generic h5-backed Dataset
+# ---------------------------------------------------------------------------
+#
+# Mirrors the interface consumed by :class:`EphysDataModule`:
+#   - Top-level object exposes ``.dataset`` as a ``ConcatDataset`` of
+#     per-session sub-datasets.
+#   - Each sub-dataset has a ``.subject`` attribute for subject-level splits.
+#   - Each item is ``{"data": (L, C) float32, "times": (L,) float32,
+#     "info": dict}``.
+#
+# Each h5 file is expected to hold a single ``"data"`` dataset of shape
+# ``(n_samples, n_channels)``.
+
+
+class H5Session(Dataset):
+    """Non-overlapping windowed view of one h5 session file.
+
+    Parameters
+    ----------
+    h5_path : str
+        Path to a session h5 file containing a single ``"data"`` dataset of
+        shape ``(n_samples, n_channels)``.
+    window_len : int
+        Window length in samples; windows are non-overlapping. Any trailing
+        samples shorter than ``window_len`` are dropped.
+    sfreq : float
+        Sample rate in Hz, used only to populate the ``times`` array.
+    info : dict
+        Metadata dict copied into every item's ``info`` field. Must contain a
+        ``"subject"`` key if subject-level splitting will be used downstream.
+    standardize : bool
+        If True, compute per-channel mean/std across the full session at
+        construction time and apply z-score normalisation in ``__getitem__``.
+    """
+
+    def __init__(
+        self,
+        h5_path: str,
+        window_len: int,
+        sfreq: float,
+        info: Dict[str, Any],
+        standardize: bool = True,
+    ):
+        self.h5_path = str(h5_path)
+        self.window_len = int(window_len)
+        self.sfreq = float(sfreq)
+        self.info = dict(info)
+        self.standardize = bool(standardize)
+
+        # Required by EphysDataModule subject-splitting.
+        self.subject = self.info.get("subject")
+
+        with h5py.File(self.h5_path, "r") as f:
+            ds = f["data"]
+            self.n_samples = int(ds.shape[0])
+            self.n_channels = int(ds.shape[1])
+            if self.standardize:
+                data = ds[...].astype(np.float64)
+                self._mean = data.mean(axis=0).astype(np.float32)
+                std = data.std(axis=0)
+                std[std < 1e-8] = 1.0
+                self._std = std.astype(np.float32)
+            else:
+                self._mean = None
+                self._std = None
+
+        self.n_windows = self.n_samples // self.window_len
+        # Opened lazily per process; re-opened if the PID changes so a handle
+        # opened in the main process isn't reused (and silently desynced) by a
+        # forked DataLoader worker.
+        self._h5: Optional[h5py.File] = None
+        self._h5_pid: Optional[int] = None
+
+    def __len__(self) -> int:
+        return self.n_windows
+
+    def __getstate__(self) -> Dict[str, Any]:
+        # Strip the handle so the dataset pickles cleanly into workers
+        # (spawn multiprocessing context).
+        state = self.__dict__.copy()
+        state["_h5"] = None
+        state["_h5_pid"] = None
+        return state
+
+    def _handle(self) -> h5py.File:
+        pid = os.getpid()
+        if self._h5 is None or self._h5_pid != pid:
+            self._h5 = h5py.File(self.h5_path, "r")
+            self._h5_pid = pid
+        return self._h5
+
+    def __getitem__(self, idx: int) -> Dict[str, Any]:
+        if idx < 0 or idx >= self.n_windows:
+            raise IndexError(idx)
+        start = idx * self.window_len
+        end = start + self.window_len
+        data = self._handle()["data"][start:end, :].astype(np.float32, copy=False)
+        if self.standardize:
+            data = (data - self._mean) / self._std
+        times = np.arange(start, end, dtype=np.float32) / self.sfreq
+        return {"data": data, "times": times, "info": self.info}
+
+
+class H5Dataset(Dataset):
+    """Concatenation of :class:`H5Session` datasets.
+
+    Exposes ``.dataset`` as a ``ConcatDataset`` so it is a drop-in replacement
+    for ``pnpl.datasets.CamcanGlasser`` when used with
+    :class:`EphysDataModule`.
+ """ + + def __init__(self, sessions: Sequence[H5Session]): + if len(sessions) == 0: + raise ValueError("H5Dataset requires at least one session.") + self.dataset = ConcatDataset(list(sessions)) + + def __len__(self) -> int: + return len(self.dataset) + + def __getitem__(self, idx: int) -> Dict[str, Any]: + return self.dataset[idx] + + +def build_h5_dataset( + sessions_csv: str, + h5_dir: str, + window_len: int, + sfreq: float = 250.0, + standardize: bool = True, + include_sessions: Optional[Sequence[str]] = None, + info_cols: Sequence[str] = ( + "session", + "dataset", + "subject", + "task", + "system", + "age", + "sex", + ), +) -> H5Dataset: + """Build an :class:`H5Dataset` from a sessions CSV. + + Parameters + ---------- + sessions_csv : str + Path to a CSV indexing h5 sessions. Must contain a ``session`` column + and columns named in ``info_cols``. + h5_dir : str + Directory containing ``{session}.h5`` files. + window_len : int + Number of samples per window. + sfreq : float + Sample rate in Hz (default 250 Hz). + standardize : bool + Apply per-session, per-channel z-score normalisation. + include_sessions : Optional[Sequence[str]] + If given, restrict to these session IDs. + info_cols : Sequence[str] + Columns copied from the CSV into each item's ``info`` dict. + """ + df = pd.read_csv(sessions_csv) + if include_sessions is not None: + df = df[df["session"].isin(set(include_sessions))] + if df.empty: + raise ValueError("No sessions selected from CSV.") + + h5_dir_path = Path(h5_dir) + sessions: List[H5Session] = [] + for _, row in df.iterrows(): + info = {c: row[c] for c in info_cols if c in row} + h5_path = h5_dir_path / f"{row['session']}.h5" + sessions.append( + H5Session( + h5_path=str(h5_path), + window_len=window_len, + sfreq=sfreq, + info=info, + standardize=standardize, + ) + ) + return H5Dataset(sessions) diff --git a/examples/train_etkn.py b/examples/train_etkn.py index 3f8dc84..cb91361 100644 --- a/examples/train_etkn.py +++ b/examples/train_etkn.py @@ -11,7 +11,7 @@ from pytorch_lightning.loggers import CSVLogger from ephys_tokenizer.configs import get_config -from ephys_tokenizer.data.dataloader import CamcanGlasserDataModule +from ephys_tokenizer.data.dataloader import EphysDataModule from ephys_tokenizer.models import callbacks from ephys_tokenizer.models.ephys_tokenizer import EphysTokenizerModule from ephys_tokenizer.utils import plotting @@ -71,7 +71,7 @@ def main(cfg: DictConfig): include_subjects=subject_ids, verbose=False, ) - camcan_datamodule = CamcanGlasserDataModule( + camcan_datamodule = EphysDataModule( dataset=camcan_data, batch_size=batch_size, val_split=0, From 5e2a2d2ed5140fd5c5e25b9e5347bb6af8986f84 Mon Sep 17 00:00:00 2001 From: Chetan Gohil Date: Sat, 18 Apr 2026 08:36:35 +0100 Subject: [PATCH 2/5] enhance: robust per-session standardisation for H5Session MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace mean/std with median + 1.4826·MAD. For Gaussian data the two are asymptotically identical (1.4826 is the MAD→σ consistency factor), but MAD is unaffected by localised high-amplitude artefacts whose inflated σ otherwise scales the rest of the session down and destroys reconstruction quality on long or noisy recordings. 
---
 ephys_tokenizer/data/dataloader.py | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/ephys_tokenizer/data/dataloader.py b/ephys_tokenizer/data/dataloader.py
index 1dc6e08..2038b08 100644
--- a/ephys_tokenizer/data/dataloader.py
+++ b/ephys_tokenizer/data/dataloader.py
@@ -489,8 +489,9 @@ class H5Session(Dataset):
         Metadata dict copied into every item's ``info`` field. Must contain a
         ``"subject"`` key if subject-level splitting will be used downstream.
     standardize : bool
-        If True, compute per-channel mean/std across the full session at
-        construction time and apply z-score normalisation in ``__getitem__``.
+        If True, apply per-channel, per-session standardisation using the
+        outlier-robust estimator median + 1.4826·MAD (consistent with mean/std
+        for Gaussian data, but unaffected by localised artefacts).
     """
 
     def __init__(
@@ -516,10 +517,11 @@ def __init__(
             self.n_channels = int(ds.shape[1])
             if self.standardize:
                 data = ds[...].astype(np.float64)
-                self._mean = data.mean(axis=0).astype(np.float32)
-                std = data.std(axis=0)
-                std[std < 1e-8] = 1.0
-                self._std = std.astype(np.float32)
+                centre = np.median(data, axis=0)
+                scale = 1.4826 * np.median(np.abs(data - centre), axis=0)
+                scale[scale < 1e-8] = 1.0
+                self._mean = centre.astype(np.float32)
+                self._std = scale.astype(np.float32)
             else:
                 self._mean = None
                 self._std = None
@@ -612,7 +614,8 @@ def build_h5_dataset(
     sfreq : float
         Sample rate in Hz (default 250 Hz).
     standardize : bool
-        Apply per-session, per-channel z-score normalisation.
+        Apply per-session, per-channel robust standardisation
+        (median + 1.4826·MAD).
     include_sessions : Optional[Sequence[str]]
         If given, restrict to these session IDs.
     info_cols : Sequence[str]

From 31c2b25137ebab4d9494c9c48fd0f0a8747abe0c Mon Sep 17 00:00:00 2001
From: Chetan Gohil
Date: Sat, 18 Apr 2026 08:37:09 +0100
Subject: [PATCH 3/5] fix: stream per-session stats in get_pve instead of
 preallocating full dataset
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The old implementation preallocated two (n_total_sequences, L, C)
float32 tensors on the compute device, which is fine for a ~50-subject
CamCAN subset but OOMs on larger datasets — e.g. ~24 GB combined for
277k windows at L=200, C=54. Replace with a streaming pass that
accumulates per-session sums of squared error and squared total on the
host (O(n_sessions) memory). np.searchsorted maps each batch element to
its session in O(log n).
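[Editor's note: a minimal standalone sketch of the searchsorted mapping used below. The `ranges` triples here are made up, but follow the `(name, start, end)` layout that `get_pve` iterates over.]

```python
import numpy as np

# Contiguous per-session window ranges: (name, start, end), end exclusive.
ranges = [("sess-a", 0, 3), ("sess-b", 3, 7), ("sess-c", 7, 10)]
range_starts = np.array([r[1] for r in ranges], dtype=np.int64)  # [0, 3, 7]

# Global indices of the windows in one batch.
positions = np.arange(2, 6, dtype=np.int64)  # windows 2, 3, 4, 5

# Index of the last session whose start is <= each position: O(log n) per item.
sess_ix = np.searchsorted(range_starts, positions, side="right") - 1
print(sess_ix)  # [0 1 1 1] -> window 2 is in sess-a, windows 3-5 in sess-b
```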
---
 ephys_tokenizer/models/ephys_tokenizer.py | 53 +++++++++++------------
 1 file changed, 21 insertions(+), 32 deletions(-)

diff --git a/ephys_tokenizer/models/ephys_tokenizer.py b/ephys_tokenizer/models/ephys_tokenizer.py
index 3cb2152..6b1b8d0 100644
--- a/ephys_tokenizer/models/ephys_tokenizer.py
+++ b/ephys_tokenizer/models/ephys_tokenizer.py
@@ -652,12 +652,13 @@ def get_pve(
 
         # Put model in evaluation mode
         model = self.model.to(device).eval()
 
-        sequence_length = self.config.sequence_length
-        n_channels = self.config.n_channels
-        # Preallocate arrays for tokens (and weights if requested)
-        original_x = torch.empty((n_total_sequences, sequence_length, n_channels), dtype=torch.float32, device=device)
-        reconstructed_x = torch.empty((n_total_sequences, sequence_length, n_channels), dtype=torch.float32, device=device)
+        # Stream per-session sum-of-squared-error and sum-of-squared-total
+        # through the dataloader
+        n_sessions = len(ranges)
+        sess_sse = np.zeros(n_sessions, dtype=np.float64)
+        sess_sst = np.zeros(n_sessions, dtype=np.float64)
+        range_starts = np.array([r[1] for r in ranges], dtype=np.int64)
 
         idx = 0
         with torch.inference_mode():
@@ -668,37 +669,25 @@ def get_pve(
                 _, rx, _ = model(x)  # shape: (B, L, C)
 
-                bsz = x.shape[0]  # actual batch size (last batch may be smaller)
-                original_x[idx:idx + bsz].copy_(x)
-                reconstructed_x[idx:idx + bsz].copy_(rx)
+                bsz = x.shape[0]
+                sse_b = ((x - rx) ** 2).sum(dim=(1, 2)).cpu().numpy()  # (B,)
+                sst_b = (x ** 2).sum(dim=(1, 2)).cpu().numpy()  # (B,)
+
+                # Map each window to its session index (ranges are contiguous)
+                positions = np.arange(idx, idx + bsz, dtype=np.int64)
+                sess_ix = np.searchsorted(range_starts, positions, side="right") - 1
+                np.add.at(sess_sse, sess_ix, sse_b)
+                np.add.at(sess_sst, sess_ix, sst_b)
 
                 idx += bsz
 
 
-        # Move all tensors to CPU at once
-        original_x = original_x.cpu().numpy()
-        reconstructed_x = reconstructed_x.cpu().numpy()
-
-        # Split by subject ranges
-        all_original_x = []
-        all_reconstructed_x = []
-        for _, start, end in ranges:
-            all_original_x.append(original_x[start:end].reshape(-1, n_channels))
-            all_reconstructed_x.append(reconstructed_x[start:end].reshape(-1, n_channels))
-            # all_*_x.shape: (N, T, C)
-
-        pve = []
-        for x, rx in tqdm(
-            zip(all_original_x, all_reconstructed_x),
-            desc="Calculating Percentage of Variance Explained ...",
-            total=len(all_original_x),
-        ):
-            pve.append(
-                100 * (1 - np.sum((x - rx) ** 2) / np.sum(x ** 2))
-            )
+        with np.errstate(divide="ignore", invalid="ignore"):
+            pve = 100.0 * (1.0 - sess_sse / sess_sst)
+        pve = np.where(sess_sst > 0, pve, 0.0)
 
-        if len(pve) == 1:
-            return pve[0]
-        return np.array(pve)
+        if n_sessions == 1:
+            return float(pve[0])
+        return pve
 
     def get_token_kernel_response(
         self,

From 7a5225816d322d33e9b63a1a102ef095695dc4f8 Mon Sep 17 00:00:00 2001
From: Chetan Gohil
Date: Sat, 18 Apr 2026 09:48:37 +0100
Subject: [PATCH 4/5] Increased tokeniser figure size.

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index de9c16c..ff6c871 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
 **EphysTokenizer** is a data-driven, sample-level tokenizer for non-invasive human electrophysiological signals (MEG/EEG). It discretizes continuous neural time series into integer token sequences at each time step. By training an autoencoder with an RNN-based encoder and a convolutional decoder, the model learns a quantization scheme through signal reconstruction, enabling end-to-end tokenization directly from raw time-domain samples.
-  <img src="…" alt="EphysTokenizer Overview" width="…">
+  <img src="…" alt="EphysTokenizer Overview" width="…">
 
 Overview of the EphysTokenizer Architecture
 
From 03a644e7bdd07dcf2c8292b63f2620e0475f60ba Mon Sep 17 00:00:00 2001
From: Chetan Gohil
Date: Sat, 18 Apr 2026 20:11:34 +0100
Subject: [PATCH 5/5] revert: H5Session standardisation back to z-score

Outlier sessions are now excluded upstream via the curated subset, so
robust MAD estimation is no longer needed.
---
 ephys_tokenizer/data/dataloader.py | 17 +++++++----------
 1 file changed, 7 insertions(+), 10 deletions(-)

diff --git a/ephys_tokenizer/data/dataloader.py b/ephys_tokenizer/data/dataloader.py
index 2038b08..ea02cb1 100644
--- a/ephys_tokenizer/data/dataloader.py
+++ b/ephys_tokenizer/data/dataloader.py
@@ -489,9 +489,8 @@ class H5Session(Dataset):
         Metadata dict copied into every item's ``info`` field. Must contain a
         ``"subject"`` key if subject-level splitting will be used downstream.
     standardize : bool
-        If True, apply per-channel, per-session standardisation using the
-        outlier-robust estimator median + 1.4826·MAD (consistent with mean/std
-        for Gaussian data, but unaffected by localised artefacts).
+        If True, apply per-channel, per-session z-score standardisation
+        (subtract mean, divide by standard deviation).
     """
 
     def __init__(
@@ -517,11 +516,10 @@ def __init__(
             self.n_channels = int(ds.shape[1])
             if self.standardize:
                 data = ds[...].astype(np.float64)
-                centre = np.median(data, axis=0)
-                scale = 1.4826 * np.median(np.abs(data - centre), axis=0)
-                scale[scale < 1e-8] = 1.0
-                self._mean = centre.astype(np.float32)
-                self._std = scale.astype(np.float32)
+                self._mean = data.mean(axis=0).astype(np.float32)
+                std = data.std(axis=0)
+                std[std < 1e-8] = 1.0  # keep the zero-variance guard from the original z-score code
+                self._std = std.astype(np.float32)
             else:
                 self._mean = None
                 self._std = None
@@ -614,8 +612,7 @@ def build_h5_dataset(
     sfreq : float
         Sample rate in Hz (default 250 Hz).
     standardize : bool
-        Apply per-session, per-channel robust standardisation
-        (median + 1.4826·MAD).
+        Apply per-session, per-channel z-score standardisation.
     include_sessions : Optional[Sequence[str]]
         If given, restrict to these session IDs.
     info_cols : Sequence[str]
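[Editor's note: for reference, a minimal end-to-end sketch of the interface this series converges on. The CSV/h5 paths and the batch size are placeholders; `EphysDataModule` keyword arguments beyond `dataset`, `batch_size`, and `val_split` (the ones visible in `examples/train_etkn.py`) are left at their defaults, and any required Lightning `setup()` call is assumed to be handled by the module.]

```python
from ephys_tokenizer.data.dataloader import EphysDataModule, build_h5_dataset

# 200-sample windows at 250 Hz = 0.8 s per item (values illustrative).
dataset = build_h5_dataset(
    sessions_csv="data/sessions.csv",  # placeholder path
    h5_dir="data/h5",                  # placeholder path
    window_len=200,
    sfreq=250.0,
    standardize=True,
)

datamodule = EphysDataModule(
    dataset=dataset,
    batch_size=32,
    val_split=0,
)

# full_dataloader() iterates every window; items collate via _collate_default.
batch = next(iter(datamodule.full_dataloader()))
print(batch["data"].shape)  # expected (B, L, C), per the collate docstring
```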