diff --git a/README.md b/README.md
index de9c16c..ff6c871 100644
--- a/README.md
+++ b/README.md
@@ -3,7 +3,7 @@
**EphysTokenizer** is a data-driven, sample-level tokenizer for non-invasive human electrophysiological signals (MEG/EEG). It discretizes continuous neural time series into integer token sequences at each time step. By training an autoencoder with an RNN-based encoder and a convolutional decoder, the model learns a quantization scheme through signal reconstruction, enabling end-to-end tokenization directly from raw time-domain samples.
-

+
Overview of the EphysTokenizer Architecture
diff --git a/ephys_tokenizer/data/dataloader.py b/ephys_tokenizer/data/dataloader.py
index 91ce6e1..ea02cb1 100644
--- a/ephys_tokenizer/data/dataloader.py
+++ b/ephys_tokenizer/data/dataloader.py
@@ -1,12 +1,22 @@
-"""PyTorch DataLoader and Lightning DataModule for tokenizer training."""
+"""PyTorch DataLoader and Lightning DataModule for tokenizer training.
+
+Also provides a generic h5-backed ``Dataset`` (``H5Session`` / ``H5Dataset`` /
+``build_h5_dataset``) that plugs into :class:`EphysDataModule` as a drop-in
+replacement for ``pnpl.datasets.CamcanGlasser``.
+"""
# Import packages
from __future__ import annotations
+import os
+import random
+
+import h5py
import numpy as np
+import pandas as pd
import pytorch_lightning as pl
-import random
import torch
+from pathlib import Path
from torch.utils.data import (
ConcatDataset,
DataLoader,
@@ -45,9 +55,9 @@ def _default_worker_init_fn(worker_id: int) -> None:
torch.manual_seed(seed)
-def _collate_camcan(batch: Sequence[dict]) -> Dict[str, Any]:
+def _collate_default(batch: Sequence[dict]) -> Dict[str, Any]:
"""
- Collate function for CamcanGlasser items.
+ Default collate function for session-window items.
Each item is expected to be:
{"data": np.ndarray (L, C), "times": np.ndarray (L,), "info": dict}
@@ -84,7 +94,7 @@ def _make_dataloader(
persistent_workers: Optional[bool] = True,
drop_last: Optional[bool] = False,
sampler: Optional[Sampler] = None,
- collate_fn=_collate_camcan,
+ collate_fn=_collate_default,
) -> DataLoader:
"""
Creates a DataLoader for a given dataset.
@@ -130,9 +140,13 @@ def _make_dataloader(
)
-class CamcanGlasserDataModule(pl.LightningDataModule):
+class EphysDataModule(pl.LightningDataModule):
"""
- Lightning DataModule for the CamCAN dataset.
+ Lightning DataModule for continuous session-based datasets.
+
+ Expects the input dataset to wrap a ``ConcatDataset`` of per-session
+ sub-datasets (accessible via ``dataset.dataset``). Each sub-dataset must
+ expose a ``.subject`` attribute if subject-level splitting is used.
Parameters
----------
@@ -441,3 +455,185 @@ def full_dataloader(self) -> DataLoader:
persistent_workers=self.persistent_workers,
drop_last=False,
)
+
+
+# ---------------------------------------------------------------------------
+# Generic h5-backed Dataset
+# ---------------------------------------------------------------------------
+#
+# Mirrors the interface consumed by :class:`EphysDataModule`:
+# - Top-level object exposes ``.dataset`` as a ``ConcatDataset`` of
+# per-session sub-datasets.
+# - Each sub-dataset has a ``.subject`` attribute for subject-level splits.
+# - Each item is ``{"data": (L, C) float32, "times": (L,) float32,
+# "info": dict}``.
+#
+# Each h5 file is expected to hold a single ``"data"`` dataset of shape
+# ``(n_samples, n_channels)``.
+
+
+class H5Session(Dataset):
+ """Non-overlapping windowed view of one h5 session file.
+
+ Parameters
+ ----------
+ h5_path : str
+ Path to a session h5 file containing a single ``"data"`` dataset of
+ shape ``(n_samples, n_channels)``.
+ window_len : int
+ Window length in samples; windows are non-overlapping. Any trailing
+ samples shorter than ``window_len`` are dropped.
+ sfreq : float
+ Sample rate in Hz, used only to populate the ``times`` array.
+ info : dict
+ Metadata dict copied into every item's ``info`` field. Must contain a
+ ``"subject"`` key if subject-level splitting will be used downstream.
+ standardize : bool
+ If True, apply per-channel, per-session z-score standardisation
+ (subtract mean, divide by standard deviation).
+ """
+
+ def __init__(
+ self,
+ h5_path: str,
+ window_len: int,
+ sfreq: float,
+ info: Dict[str, Any],
+ standardize: bool = True,
+ ):
+ self.h5_path = str(h5_path)
+ self.window_len = int(window_len)
+ self.sfreq = float(sfreq)
+ self.info = dict(info)
+ self.standardize = bool(standardize)
+
+ # Required by EphysDataModule subject-splitting.
+ self.subject = self.info.get("subject")
+
+ with h5py.File(self.h5_path, "r") as f:
+ ds = f["data"]
+ self.n_samples = int(ds.shape[0])
+ self.n_channels = int(ds.shape[1])
+ if self.standardize:
+ data = ds[...].astype(np.float64)
+ self._mean = data.mean(axis=0).astype(np.float32)
+ self._std = data.std(axis=0).astype(np.float32)
+ else:
+ self._mean = None
+ self._std = None
+
+ self.n_windows = self.n_samples // self.window_len
+ # Opened lazily per process; re-opened if the PID changes so a handle
+ # opened in the main process isn't reused (and silently desynced) by a
+ # forked DataLoader worker.
+ self._h5: Optional[h5py.File] = None
+ self._h5_pid: Optional[int] = None
+
+ def __len__(self) -> int:
+ return self.n_windows
+
+ def __getstate__(self) -> Dict[str, Any]:
+ # Strip the handle so the dataset pickles cleanly into workers
+ # (spawn multiprocessing context).
+ state = self.__dict__.copy()
+ state["_h5"] = None
+ state["_h5_pid"] = None
+ return state
+
+ def _handle(self) -> h5py.File:
+ pid = os.getpid()
+ if self._h5 is None or self._h5_pid != pid:
+ self._h5 = h5py.File(self.h5_path, "r")
+ self._h5_pid = pid
+ return self._h5
+
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
+ if idx < 0 or idx >= self.n_windows:
+ raise IndexError(idx)
+ start = idx * self.window_len
+ end = start + self.window_len
+ data = self._handle()["data"][start:end, :].astype(np.float32, copy=False)
+ if self.standardize:
+ data = (data - self._mean) / self._std
+ times = np.arange(start, end, dtype=np.float32) / self.sfreq
+ return {"data": data, "times": times, "info": self.info}
+
+
+class H5Dataset(Dataset):
+ """Concatenation of :class:`H5Session` datasets.
+
+ Exposes ``.dataset`` as a ``ConcatDataset`` so it is a drop-in replacement
+ for ``pnpl.datasets.CamcanGlasser`` when used with
+ :class:`EphysDataModule`.
+ """
+
+ def __init__(self, sessions: Sequence[H5Session]):
+ if len(sessions) == 0:
+ raise ValueError("H5Dataset requires at least one session.")
+ self.dataset = ConcatDataset(list(sessions))
+
+ def __len__(self) -> int:
+ return len(self.dataset)
+
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
+ return self.dataset[idx]
+
+
+def build_h5_dataset(
+ sessions_csv: str,
+ h5_dir: str,
+ window_len: int,
+ sfreq: float = 250.0,
+ standardize: bool = True,
+ include_sessions: Optional[Sequence[str]] = None,
+ info_cols: Sequence[str] = (
+ "session",
+ "dataset",
+ "subject",
+ "task",
+ "system",
+ "age",
+ "sex",
+ ),
+) -> H5Dataset:
+ """Build an :class:`H5Dataset` from a sessions CSV.
+
+ Parameters
+ ----------
+ sessions_csv : str
+ Path to a CSV indexing h5 sessions. Must contain a ``session`` column
+ and columns named in ``info_cols``.
+ h5_dir : str
+ Directory containing ``{session}.h5`` files.
+ window_len : int
+ Number of samples per window.
+ sfreq : float
+ Sample rate in Hz (default 250 Hz).
+ standardize : bool
+ Apply per-session, per-channel z-score standardisation.
+ include_sessions : Optional[Sequence[str]]
+ If given, restrict to these session IDs.
+ info_cols : Sequence[str]
+ Columns copied from the CSV into each item's ``info`` dict.
+ """
+ df = pd.read_csv(sessions_csv)
+ if include_sessions is not None:
+ df = df[df["session"].isin(set(include_sessions))]
+ if df.empty:
+ raise ValueError("No sessions selected from CSV.")
+
+ h5_dir_path = Path(h5_dir)
+ sessions: List[H5Session] = []
+ for _, row in df.iterrows():
+ info = {c: row[c] for c in info_cols if c in row}
+ h5_path = h5_dir_path / f"{row['session']}.h5"
+ sessions.append(
+ H5Session(
+ h5_path=str(h5_path),
+ window_len=window_len,
+ sfreq=sfreq,
+ info=info,
+ standardize=standardize,
+ )
+ )
+ return H5Dataset(sessions)
diff --git a/ephys_tokenizer/models/ephys_tokenizer.py b/ephys_tokenizer/models/ephys_tokenizer.py
index 3cb2152..6b1b8d0 100644
--- a/ephys_tokenizer/models/ephys_tokenizer.py
+++ b/ephys_tokenizer/models/ephys_tokenizer.py
@@ -652,12 +652,13 @@ def get_pve(
# Put model in evaluation mode
model = self.model.to(device).eval()
- sequence_length = self.config.sequence_length
- n_channels = self.config.n_channels
- # Preallocate arrays for tokens (and weights if requested)
- original_x = torch.empty((n_total_sequences, sequence_length, n_channels), dtype=torch.float32, device=device)
- reconstructed_x = torch.empty((n_total_sequences, sequence_length, n_channels), dtype=torch.float32, device=device)
+    # Stream per-session sums of squared errors (SSE) and total sums of
+    # squares (SST) through the dataloader.
+ n_sessions = len(ranges)
+ sess_sse = np.zeros(n_sessions, dtype=np.float64)
+ sess_sst = np.zeros(n_sessions, dtype=np.float64)
+ range_starts = np.array([r[1] for r in ranges], dtype=np.int64)
idx = 0
with torch.inference_mode():
@@ -668,37 +669,25 @@ def get_pve(
_, rx, _ = model(x) # shape: (B, L, C)
- bsz = x.shape[0] # actual batch size (last batch may be smaller)
- original_x[idx:idx + bsz].copy_(x)
- reconstructed_x[idx:idx + bsz].copy_(rx)
+ bsz = x.shape[0]
+ sse_b = ((x - rx) ** 2).sum(dim=(1, 2)).cpu().numpy() # (B,)
+ sst_b = (x ** 2).sum(dim=(1, 2)).cpu().numpy() # (B,)
+
+ # Map each window to its session index (ranges are contiguous)
+ positions = np.arange(idx, idx + bsz, dtype=np.int64)
+ sess_ix = np.searchsorted(range_starts, positions, side="right") - 1
+ np.add.at(sess_sse, sess_ix, sse_b)
+ np.add.at(sess_sst, sess_ix, sst_b)
idx += bsz
- # Move all tensors to CPU at once
- original_x = original_x.cpu().numpy()
- reconstructed_x = reconstructed_x.cpu().numpy()
-
- # Split by subject ranges
- all_original_x = []
- all_reconstructed_x = []
- for _, start, end in ranges:
- all_original_x.append(original_x[start:end].reshape(-1, n_channels))
- all_reconstructed_x.append(reconstructed_x[start:end].reshape(-1, n_channels))
- # all_*_x.shape: (N, T, C)
-
- pve = []
- for x, rx in tqdm(
- zip(all_original_x, all_reconstructed_x),
- desc="Calculating Percentage of Variance Explained ...",
- total=len(all_original_x),
- ):
- pve.append(
- 100 * (1 - np.sum((x - rx) ** 2) / np.sum(x ** 2))
- )
+ with np.errstate(divide="ignore", invalid="ignore"):
+ pve = 100.0 * (1.0 - sess_sse / sess_sst)
+ pve = np.where(sess_sst > 0, pve, 0.0)
- if len(pve) == 1:
- return pve[0]
- return np.array(pve)
+ if n_sessions == 1:
+ return float(pve[0])
+ return pve
def get_token_kernel_response(
self,
diff --git a/examples/train_etkn.py b/examples/train_etkn.py
index 3f8dc84..cb91361 100644
--- a/examples/train_etkn.py
+++ b/examples/train_etkn.py
@@ -11,7 +11,7 @@
from pytorch_lightning.loggers import CSVLogger
from ephys_tokenizer.configs import get_config
-from ephys_tokenizer.data.dataloader import CamcanGlasserDataModule
+from ephys_tokenizer.data.dataloader import EphysDataModule
from ephys_tokenizer.models import callbacks
from ephys_tokenizer.models.ephys_tokenizer import EphysTokenizerModule
from ephys_tokenizer.utils import plotting
@@ -71,7 +71,7 @@ def main(cfg: DictConfig):
include_subjects=subject_ids,
verbose=False,
)
- camcan_datamodule = CamcanGlasserDataModule(
+ camcan_datamodule = EphysDataModule(
dataset=camcan_data,
batch_size=batch_size,
val_split=0,