Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
**EphysTokenizer** is a data-driven, sample-level tokenizer for non-invasive human electrophysiological signals (MEG/EEG). It discretizes continuous neural time series into integer token sequences at each time step. By training an autoencoder with an RNN-based encoder and a convolutional decoder, the model learns a quantization scheme through signal reconstruction, enabling end-to-end tokenization directly from raw time-domain samples.

<div align="center">
<img src="assets/model_architecture.png" alt="EphysTokenizer Overview" width="25%">
<img src="assets/model_architecture.png" alt="EphysTokenizer Overview" width="40%">
<p><strong>Overview of the EphysTokenizer Architecture</strong></p>
</div>

Expand Down
210 changes: 203 additions & 7 deletions ephys_tokenizer/data/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
"""PyTorch DataLoader and Lightning DataModule for tokenizer training."""
"""PyTorch DataLoader and Lightning DataModule for tokenizer training.

Also provides a generic h5-backed ``Dataset`` (``H5Session`` / ``H5Dataset`` /
``build_h5_dataset``) that plugs into :class:`EphysDataModule` as a drop-in
replacement for ``pnpl.datasets.CamcanGlasser``.
"""

# Import packages
from __future__ import annotations

import os
import random

import h5py
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import random
import torch
from pathlib import Path
from torch.utils.data import (
ConcatDataset,
DataLoader,
Expand Down Expand Up @@ -45,9 +55,9 @@ def _default_worker_init_fn(worker_id: int) -> None:
torch.manual_seed(seed)


def _collate_camcan(batch: Sequence[dict]) -> Dict[str, Any]:
def _collate_default(batch: Sequence[dict]) -> Dict[str, Any]:
"""
Collate function for CamcanGlasser items.
Default collate function for session-window items.

Each item is expected to be:
{"data": np.ndarray (L, C), "times": np.ndarray (L,), "info": dict}
Expand Down Expand Up @@ -84,7 +94,7 @@ def _make_dataloader(
persistent_workers: Optional[bool] = True,
drop_last: Optional[bool] = False,
sampler: Optional[Sampler] = None,
collate_fn=_collate_camcan,
collate_fn=_collate_default,
) -> DataLoader:
"""
Creates a DataLoader for a given dataset.
Expand Down Expand Up @@ -130,9 +140,13 @@ def _make_dataloader(
)


class CamcanGlasserDataModule(pl.LightningDataModule):
class EphysDataModule(pl.LightningDataModule):
"""
Lightning DataModule for the CamCAN dataset.
Lightning DataModule for continuous session-based datasets.

Expects the input dataset to wrap a ``ConcatDataset`` of per-session
sub-datasets (accessible via ``dataset.dataset``). Each sub-dataset must
expose a ``.subject`` attribute if subject-level splitting is used.

Parameters
----------
Expand Down Expand Up @@ -441,3 +455,185 @@ def full_dataloader(self) -> DataLoader:
persistent_workers=self.persistent_workers,
drop_last=False,
)


# ---------------------------------------------------------------------------
# Generic h5-backed Dataset
# ---------------------------------------------------------------------------
#
# Mirrors the interface consumed by :class:`EphysDataModule`:
# - Top-level object exposes ``.dataset`` as a ``ConcatDataset`` of
# per-session sub-datasets.
# - Each sub-dataset has a ``.subject`` attribute for subject-level splits.
# - Each item is ``{"data": (L, C) float32, "times": (L,) float32,
# "info": dict}``.
#
# Each h5 file is expected to hold a single ``"data"`` dataset of shape
# ``(n_samples, n_channels)``.


class H5Session(Dataset):
    """Non-overlapping windowed view of one h5 session file.

    Each h5 file is expected to hold a single ``"data"`` dataset of shape
    ``(n_samples, n_channels)``. Items are dicts of the form
    ``{"data": (L, C) float32, "times": (L,) float32, "info": dict}``.

    Parameters
    ----------
    h5_path : str
        Path to a session h5 file containing a single ``"data"`` dataset of
        shape ``(n_samples, n_channels)``.
    window_len : int
        Window length in samples; windows are non-overlapping. Any trailing
        samples shorter than ``window_len`` are dropped.
    sfreq : float
        Sample rate in Hz, used only to populate the ``times`` array.
    info : dict
        Metadata dict copied into every item's ``info`` field. Must contain a
        ``"subject"`` key if subject-level splitting will be used downstream.
    standardize : bool
        If True, apply per-channel, per-session z-score standardisation
        (subtract mean, divide by standard deviation). Zero-variance channels
        are left unscaled (divisor forced to 1) to avoid inf/NaN output.
    """

    def __init__(
        self,
        h5_path: str,
        window_len: int,
        sfreq: float,
        info: Dict[str, Any],
        standardize: bool = True,
    ):
        self.h5_path = str(h5_path)
        self.window_len = int(window_len)
        self.sfreq = float(sfreq)
        self.info = dict(info)
        self.standardize = bool(standardize)

        # Required by EphysDataModule subject-splitting.
        self.subject = self.info.get("subject")

        with h5py.File(self.h5_path, "r") as f:
            ds = f["data"]
            self.n_samples = int(ds.shape[0])
            self.n_channels = int(ds.shape[1])
            if self.standardize:
                # Stats are computed eagerly over the whole session; float64
                # accumulation limits precision loss on long recordings.
                data = ds[...].astype(np.float64)
                self._mean = data.mean(axis=0).astype(np.float32)
                std = data.std(axis=0).astype(np.float32)
                # BUGFIX: a constant (zero-variance) channel would otherwise
                # divide by zero in __getitem__ and fill every window with
                # inf/NaN; force its divisor to 1 so it passes through
                # (mean-subtracted but unscaled).
                std[std == 0.0] = 1.0
                self._std = std
            else:
                self._mean = None
                self._std = None

        self.n_windows = self.n_samples // self.window_len
        # Opened lazily per process; re-opened if the PID changes so a handle
        # opened in the main process isn't reused (and silently desynced) by a
        # forked DataLoader worker.
        self._h5: Optional[h5py.File] = None
        self._h5_pid: Optional[int] = None

    def __len__(self) -> int:
        return self.n_windows

    def __getstate__(self) -> Dict[str, Any]:
        # Strip the handle so the dataset pickles cleanly into workers
        # (spawn multiprocessing context).
        state = self.__dict__.copy()
        state["_h5"] = None
        state["_h5_pid"] = None
        return state

    def _handle(self) -> h5py.File:
        # Lazily (re-)open the file in the current process; an h5py handle
        # must never be shared across a fork boundary.
        pid = os.getpid()
        if self._h5 is None or self._h5_pid != pid:
            self._h5 = h5py.File(self.h5_path, "r")
            self._h5_pid = pid
        return self._h5

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        if idx < 0 or idx >= self.n_windows:
            raise IndexError(idx)
        start = idx * self.window_len
        end = start + self.window_len
        data = self._handle()["data"][start:end, :].astype(np.float32, copy=False)
        if self.standardize:
            data = (data - self._mean) / self._std
        times = np.arange(start, end, dtype=np.float32) / self.sfreq
        return {"data": data, "times": times, "info": self.info}


class H5Dataset(Dataset):
    """Flat concatenation of per-session :class:`H5Session` datasets.

    The sessions are wrapped in a ``ConcatDataset`` exposed as the
    ``.dataset`` attribute, which makes this a drop-in replacement for
    ``pnpl.datasets.CamcanGlasser`` when paired with
    :class:`EphysDataModule`.
    """

    def __init__(self, sessions: Sequence[H5Session]):
        # An empty session list would make every split empty downstream;
        # fail loudly at construction time instead.
        if not len(sessions):
            raise ValueError("H5Dataset requires at least one session.")
        self.dataset = ConcatDataset(list(sessions))

    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        return self.dataset[index]


def build_h5_dataset(
    sessions_csv: str,
    h5_dir: str,
    window_len: int,
    sfreq: float = 250.0,
    standardize: bool = True,
    include_sessions: Optional[Sequence[str]] = None,
    info_cols: Sequence[str] = (
        "session",
        "dataset",
        "subject",
        "task",
        "system",
        "age",
        "sex",
    ),
) -> H5Dataset:
    """Assemble an :class:`H5Dataset` from the sessions listed in a CSV.

    Parameters
    ----------
    sessions_csv : str
        Path to a CSV indexing h5 sessions. Must contain a ``session`` column
        and the columns named in ``info_cols``.
    h5_dir : str
        Directory containing ``{session}.h5`` files.
    window_len : int
        Number of samples per window.
    sfreq : float
        Sample rate in Hz (default 250 Hz).
    standardize : bool
        Apply per-session, per-channel z-score standardisation.
    include_sessions : Optional[Sequence[str]]
        If given, restrict to these session IDs.
    info_cols : Sequence[str]
        Columns copied from the CSV into each item's ``info`` dict.

    Raises
    ------
    ValueError
        If no sessions remain after filtering.
    """
    table = pd.read_csv(sessions_csv)
    if include_sessions is not None:
        table = table[table["session"].isin(set(include_sessions))]
    if table.empty:
        raise ValueError("No sessions selected from CSV.")

    root = Path(h5_dir)
    return H5Dataset(
        [
            H5Session(
                h5_path=str(root / f"{row['session']}.h5"),
                window_len=window_len,
                sfreq=sfreq,
                # Only columns actually present in the CSV are copied.
                info={col: row[col] for col in info_cols if col in row},
                standardize=standardize,
            )
            for _, row in table.iterrows()
        ]
    )
53 changes: 21 additions & 32 deletions ephys_tokenizer/models/ephys_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,12 +652,13 @@ def get_pve(

# Put model in evaluation mode
model = self.model.to(device).eval()
sequence_length = self.config.sequence_length
n_channels = self.config.n_channels

# Preallocate arrays for tokens (and weights if requested)
original_x = torch.empty((n_total_sequences, sequence_length, n_channels), dtype=torch.float32, device=device)
reconstructed_x = torch.empty((n_total_sequences, sequence_length, n_channels), dtype=torch.float32, device=device)
# Stream per-session sum-of-squared-error and sum-of-squared-total
# through the dataloader
n_sessions = len(ranges)
sess_sse = np.zeros(n_sessions, dtype=np.float64)
sess_sst = np.zeros(n_sessions, dtype=np.float64)
range_starts = np.array([r[1] for r in ranges], dtype=np.int64)

idx = 0
with torch.inference_mode():
Expand All @@ -668,37 +669,25 @@ def get_pve(

_, rx, _ = model(x) # shape: (B, L, C)

bsz = x.shape[0] # actual batch size (last batch may be smaller)
original_x[idx:idx + bsz].copy_(x)
reconstructed_x[idx:idx + bsz].copy_(rx)
bsz = x.shape[0]
sse_b = ((x - rx) ** 2).sum(dim=(1, 2)).cpu().numpy() # (B,)
sst_b = (x ** 2).sum(dim=(1, 2)).cpu().numpy() # (B,)

# Map each window to its session index (ranges are contiguous)
positions = np.arange(idx, idx + bsz, dtype=np.int64)
sess_ix = np.searchsorted(range_starts, positions, side="right") - 1
np.add.at(sess_sse, sess_ix, sse_b)
np.add.at(sess_sst, sess_ix, sst_b)

idx += bsz

# Move all tensors to CPU at once
original_x = original_x.cpu().numpy()
reconstructed_x = reconstructed_x.cpu().numpy()

# Split by subject ranges
all_original_x = []
all_reconstructed_x = []
for _, start, end in ranges:
all_original_x.append(original_x[start:end].reshape(-1, n_channels))
all_reconstructed_x.append(reconstructed_x[start:end].reshape(-1, n_channels))
# all_*_x.shape: (N, T, C)

pve = []
for x, rx in tqdm(
zip(all_original_x, all_reconstructed_x),
desc="Calculating Percentage of Variance Explained ...",
total=len(all_original_x),
):
pve.append(
100 * (1 - np.sum((x - rx) ** 2) / np.sum(x ** 2))
)
with np.errstate(divide="ignore", invalid="ignore"):
pve = 100.0 * (1.0 - sess_sse / sess_sst)
pve = np.where(sess_sst > 0, pve, 0.0)

if len(pve) == 1:
return pve[0]
return np.array(pve)
if n_sessions == 1:
return float(pve[0])
return pve

def get_token_kernel_response(
self,
Expand Down
4 changes: 2 additions & 2 deletions examples/train_etkn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pytorch_lightning.loggers import CSVLogger

from ephys_tokenizer.configs import get_config
from ephys_tokenizer.data.dataloader import CamcanGlasserDataModule
from ephys_tokenizer.data.dataloader import EphysDataModule
from ephys_tokenizer.models import callbacks
from ephys_tokenizer.models.ephys_tokenizer import EphysTokenizerModule
from ephys_tokenizer.utils import plotting
Expand Down Expand Up @@ -71,7 +71,7 @@ def main(cfg: DictConfig):
include_subjects=subject_ids,
verbose=False,
)
camcan_datamodule = CamcanGlasserDataModule(
camcan_datamodule = EphysDataModule(
dataset=camcan_data,
batch_size=batch_size,
val_split=0,
Expand Down