From d5da207537ac2d72f99803da29b053e7da7c5a58 Mon Sep 17 00:00:00 2001 From: Riley Date: Wed, 16 Jul 2025 15:02:39 -0400 Subject: [PATCH 01/13] Add video.py, change Python requirement --- fptools/viz/signal.py | 2 +- fptools/viz/video.py | 30 ++++++++++++++++++++++++++++++ pyproject.toml | 2 +- 3 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 fptools/viz/video.py diff --git a/fptools/viz/signal.py b/fptools/viz/signal.py index b89ce83..3b5eb4d 100644 --- a/fptools/viz/signal.py +++ b/fptools/viz/signal.py @@ -11,7 +11,7 @@ from typing import Any, Literal, Optional, Union from fptools.io import Signal, SessionCollection -from .common import Palette, get_colormap +from fptools.viz.common import Palette, get_colormap def plot_signal( diff --git a/fptools/viz/video.py b/fptools/viz/video.py new file mode 100644 index 0000000..c0fb097 --- /dev/null +++ b/fptools/viz/video.py @@ -0,0 +1,30 @@ +import os +import cv2 +import numpy as np + + +def get_frame_image(video: str, frame_idx: int) -> np.ndarray: + """Get a frame from a video using openCV + + Args: + video: string path to the video file + frame_idx: integer index of the frame to grab + + Returns: + numpy array of the image at `frame_idx` + """ + cap = cv2.VideoCapture(video) + if not cap.isOpened(): + raise Exception(f"Could not open video file: {video}") + return + + cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) + + ret, frame = cap.read() + + if not ret: + print("Error: Could not read frame.") + return + + cap.release() + return frame \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 18abc54..2fd4465 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ authors = [ {name = "Joshua K. Thackray", email = "thackray@rutgers.edu"}, ] description="Collection of tools for working with fiber photometry data." 
-requires-python = ">=3.12,<=3.13" +requires-python = ">=3.12,<=3.13.2" keywords = ["fiber photometry", "behavior"] license = {text = "BSD-3-Clause"} classifiers = [ From 0828c8e3dba7718162c67936e306f63488217a93 Mon Sep 17 00:00:00 2001 From: Riley Date: Tue, 23 Sep 2025 16:27:39 -0400 Subject: [PATCH 02/13] modify Sesssion to support DLC data --- fptools/io/session.py | 76 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 76 insertions(+) diff --git a/fptools/io/session.py b/fptools/io/session.py index 9f82d20..bf367b7 100644 --- a/fptools/io/session.py +++ b/fptools/io/session.py @@ -25,6 +25,15 @@ def empty_array() -> np.ndarray: """ return np.ndarray([], dtype=np.float64) +def empty_df() -> pd.DataFrame: + """Create an empty Pandas dataframe + + Returns: + empty pd.DataFrame + + """ + return pd.DataFrame() + class Session(object): """Holds data and metadata for a single session.""" @@ -37,6 +46,7 @@ def __init__(self) -> None: self.signals: dict[str, Signal] = {} self.epocs: dict[str, np.ndarray] = defaultdict(empty_array) # epocs are numpy arrays, default to empty array self.scalars: dict[str, np.ndarray] = defaultdict(empty_array) # scalars are numpy arrays, default to empty array + self.dlc: dict[str, np.ndarray] = defaultdict(empty_array) # dlc data are structured numpy arrays, default to empty array def describe(self, as_str: bool = False) -> Union[str, None]: """Describe this session. 
@@ -92,6 +102,16 @@ def describe(self, as_str: bool = False) -> Union[str, None]: buffer += " < No Signals Available >\n" buffer += "\n" + buffer += "DLC Data:\n" + if len(self.dlc) > 0: + buffer += f"{len(self.dlc)} DLC arrays found: \n" + for k, v in self.dlc.items(): + buffer += f" {k}: \n" + buffer += f" array_shape = {v.shape} \n" + else: + buffer += " < No DLC arrays Available >\n" + buffer += "\n" + if as_str: return buffer else: @@ -168,6 +188,21 @@ def rename_epoc(self, old_name: str, new_name: str) -> None: self.epocs[new_name] = self.epocs[old_name] self.epocs.pop(old_name) + + def rename_dlc(self, old_name: str, new_name: str) -> None: + """Rename a dlc array, from `old_name` to `new_name`. + + Raises an error if the new dlc array name already exists. + + Args: + old_name: the current name for the dlc array + new_name: the new name for the dlc array + """ + if new_name in self.dlc: + raise KeyError(f"Key `{new_name}` already exists in data!") + + self.dlc[new_name] = self.dlc[old_name] + self.dlc.pop(old_name) def epoc_dataframe(self, include_epocs: FieldList = "all", include_meta: FieldList = "all") -> pd.DataFrame: """Produce a dataframe with epoc data and metadata. @@ -234,6 +269,30 @@ def scalar_dataframe(self, include_scalars: FieldList = "all", include_meta: Fie scalars.append({**meta, "scalar_name": sn, "scalar_value": self.scalars[sn]}) return pd.DataFrame(scalars) + + def dlc_dataframe(self, id: Union[str, int] = 0) -> pd.DataFrame: + """ + + Args: + id: identifier to select which dlc data to use in the dataframe. If str is provided, will access that named dlc data. If int is provided, will use the data from that index position among the dlc data. + + By default index 0 will be accessed. 
+ + Returns: + DataFrame with data from this session + """ + + if isinstance(id, str): + return pd.DataFrame(self.dlc[id]) + + elif isinstance(id, int): + data_list = list(self.dlc.values()) + return pd.DataFrame(data_list[id]) + + else: + raise TypeError( + 'Invalid `id` argument data type. Supported data identifier types are str and int.' + ) def __eq__(self, value: object) -> bool: """Test this Session for equality to another Session. @@ -299,6 +358,7 @@ def _estimate_memory_use_itemized(self) -> dict[str, int]: **{f"signal.{sig.name}": sig._estimate_memory_use() for sig in self.signals.values()}, **{f"epocs.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.epocs.items()}, **{f"scalars.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.scalars.items()}, + **{f"dlc.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.dlc.items()}, } def _estimate_memory_use(self) -> int: @@ -341,6 +401,11 @@ def save(self, path: str): h5.create_group("/scalars") for k, scalar in self.scalars.items(): h5.create_dataset(f"/scalars/{k}", data=scalar, compression="gzip") + + # save dlc data + h5.create_group("/dlc") + for k, dlc in self.dlc.items(): + h5.create_dataset(f"/dlc/{k}", data=dlc) # save metadata meta_group = h5.create_group("/metadata") @@ -424,6 +489,11 @@ def load(cls, path: str) -> "Session": for scalar_name in h5["/scalars"].keys(): session.scalars[scalar_name] = h5[f"/scalars/{scalar_name}"][()] + # read dlc + if "/dlc" in h5: + for dlc_name in h5["/dlc"].keys(): + session.dlc[dlc_name] = h5[f"/dlc/{dlc_name}"][()] + # read metadata if "/metadata" in h5: for meta_name in h5[f"/metadata"].keys(): @@ -765,6 +835,12 @@ def describe(self, as_str: bool = False) -> Union[str, None]: buffer += f'({v}) "{k}"\n' buffer += "\n" + dlcs = Counter([item for session in self for item in session.dlc.keys()]) + buffer += "DLC Structured Numpy Arrays present in data:\n" + for k, v in dlcs.items(): + buffer += f'({v}) "{k}"\n' + buffer += "\n" + if as_str: return buffer else: From 
d22bf88cd1b7c77b5a363e6279d0f0e33c7a7f71 Mon Sep 17 00:00:00 2001 From: Riley Date: Tue, 23 Sep 2025 16:32:26 -0400 Subject: [PATCH 03/13] implement DLC data loading in concert with TDT --- fptools/io/data_loader.py | 7 +- fptools/io/tdt_with_dlc.py | 142 +++++++++++++++++++++++++++++++++++++ 2 files changed, 147 insertions(+), 2 deletions(-) create mode 100644 fptools/io/tdt_with_dlc.py diff --git a/fptools/io/data_loader.py b/fptools/io/data_loader.py index 8331103..ce5d27c 100644 --- a/fptools/io/data_loader.py +++ b/fptools/io/data_loader.py @@ -11,6 +11,7 @@ from .common import DataLocator, DataTypeAdaptor from .med_associates import find_ma_blocks from .tdt import find_tdt_blocks +from .tdt_with_dlc import FindTDTDLCBlocks from .session import Session, SessionCollection from tqdm.auto import tqdm @@ -57,7 +58,7 @@ def load_data( manifest_path: Optional[str] = None, manifest_index: str = "blockname", max_workers: Optional[int] = None, - locator: Union[Literal["auto", "tdt", "ma"], DataLocator] = "auto", + locator: Union[Literal["auto", "tdt", "tdt_dlc", "ma"], DataLocator] = "auto", preprocess: Optional[Processor] = None, cache: bool = True, cache_dir: str = "cache", @@ -214,7 +215,7 @@ def _load( return session -def _get_locator(locator: Union[Literal["auto", "tdt", "ma"], DataLocator] = "auto") -> DataLocator: +def _get_locator(locator: Union[Literal["auto", "tdt", "tdt_dlc", "ma"], DataLocator] = "auto") -> DataLocator: """Translate a flexible locator argument to a concrete DataLoader implementation. 
Args: @@ -227,6 +228,8 @@ def _get_locator(locator: Union[Literal["auto", "tdt", "ma"], DataLocator] = "au return _find_any_data elif locator == "tdt": return find_tdt_blocks + elif locator == "tdt_dlc": + return FindTDTDLCBlocks() elif locator == "ma": return find_ma_blocks else: diff --git a/fptools/io/tdt_with_dlc.py b/fptools/io/tdt_with_dlc.py new file mode 100644 index 0000000..9b66fb6 --- /dev/null +++ b/fptools/io/tdt_with_dlc.py @@ -0,0 +1,142 @@ +import glob +import os +from typing import Optional +from pathlib import Path + +import pandas as pd + +import tdt + +from .common import DataTypeAdaptor +from .session import Session, Signal +from .tdt import TDT_EXCLUDE_STREAMS, TDTLoader + + +def has_neighboring_dlc_h5(tbk): + + dlc_id_substrs = ['DLC', 'shuffle', 'snapshot'] + neighbor_files = glob.glob(os.path.join(os.path.dirname(tbk), '*')) + neighbor_files = [file for file in neighbor_files if file.endswith('.h5')] + + dlc_files = [file for file in neighbor_files if all(sub in file for sub in dlc_id_substrs)] + + if len(dlc_files) > 0: + return True + + return False + + +class FindTDTDLCBlocks: + + def __init__(self, model_name: Optional[list[str]] = None, filtered_only: bool = True): + self.model_name = model_name + self.filtered_only = filtered_only + + + def __call__(self, path: str): + """Data Locator for TDT blocks with DLC data. + + Given a path to a directory, will search that path recursively for TDT blocks that contain DLC output files in .h5 format. 
+ + Args: + path: path to search for TDT blocks with DLC data + + Returns: + list of DataTypeAdaptor, each adaptor corresponding to one session, of data to be loaded + """ + + tbk_files = glob.glob(os.path.join(path, "**/*.[tT][bB][kK]"), recursive=True) + # tbk_files = [file for file in tbk_files if has_neighboring_dlc_h5(file)] + + items_out = [] + for tbk in tbk_files: + adapt = DataTypeAdaptor() + adapt.path = os.path.dirname(tbk) # the directory for the block + adapt.name = os.path.basename(adapt.path) # the name of the directory + adapt.loaders.append(TDTLoader(exclude_streams=TDT_EXCLUDE_STREAMS)) + adapt.loaders.append(DLCLoader(model_name=self.model_name, filtered_only=self.filtered_only)) + items_out.append(adapt) + + return items_out + +# def find_tdt_w_dlc_blocks(path: str) -> list[DataTypeAdaptor]: +# """Data Locator for TDT blocks with DLC data. + +# Given a path to a directory, will search that path recursively for TDT blocks that contain DLC output files in .h5 format. + +# Args: +# path: path to search for TDT blocks with DLC data + +# Returns: +# list of DataTypeAdaptor, each adaptor corresponding to one session, of data to be loaded +# """ + +# tbk_files = glob.glob(os.path.join(path, "**/*.[tT][bB][kK]"), recursive=True) +# tbk_files = [file for file in tbk_files if has_neighboring_dlc_h5(file)] + +# items_out = [] +# for tbk in tbk_files: +# adapt = DataTypeAdaptor() +# adapt.path = os.path.dirname(tbk) # the directory for the block +# adapt.name = os.path.basename(adapt.path) # the name of the directory +# adapt.loaders.append(TDTLoader(exclude_streams=TDT_EXCLUDE_STREAMS)) +# adapt.loaders.append(DLCLoader()) +# items_out.append(adapt) + +# return items_out + + +class DLCLoader: + def __init__( + self, + model_name: Optional[list[str]] = None, + filtered_only: bool = True + ) -> None: + """Initialize this DLCLoader.""" + self.model_name = model_name + self.filtered_only = filtered_only + + def __call__(self, session: Session, path: str) -> 
Session: + """Data Loader for DLC .h5 files in TDT blocks. + + Args: + session: the session for data to be loaded into + path: path to a TDT block folder containing DLC data + + Returns: + Session object with data added + """ + + # find the dlc .h5 file + if self.model_name is None: + pattern = os.path.join(path, f'*DLC*.h5') + files = glob.glob(pattern) + if len(files) <= 0: + raise FileNotFoundError(f"Could not find any DLC files in block {session.name}!") + + for file in files: + df = pd.read_hdf(file) + nparray = df.to_records(index=False) + key = Path(file).stem # TODO: instead look into the df for the model name + session.dlc[key] = nparray + else: + for mn in self.model_name: + if self.filtered_only: + pattern = os.path.join(path, f'*{mn}*_filtered.h5') + else: + pattern = os.path.join(path, f'*{mn}*.h5') + files = glob.glob(pattern) + if len(files) <= 0: + raise FileNotFoundError(f"Could not find any DLC files for model \"{mn}\" in block {session.name}!") + + for file in files: + df = pd.read_hdf(file) + nparray = df.to_records(index=False) + key = f"{mn}" + if "_filtered" in file: + key += "_filtered" + else: + key += "_unfiltered" + session.dlc[key] = nparray + + return session From d2f39342227d8258ff48e365490feb32b7527d60 Mon Sep 17 00:00:00 2001 From: Riley Date: Tue, 23 Sep 2025 16:52:52 -0400 Subject: [PATCH 04/13] satisfy black --- fptools/io/session.py | 31 +++++++++++++++++-------------- fptools/io/tdt_with_dlc.py | 34 +++++++++++++++------------------- fptools/viz/video.py | 2 +- 3 files changed, 33 insertions(+), 34 deletions(-) diff --git a/fptools/io/session.py b/fptools/io/session.py index bf367b7..c82f8e8 100644 --- a/fptools/io/session.py +++ b/fptools/io/session.py @@ -25,12 +25,13 @@ def empty_array() -> np.ndarray: """ return np.ndarray([], dtype=np.float64) + def empty_df() -> pd.DataFrame: """Create an empty Pandas dataframe Returns: empty pd.DataFrame - + """ return pd.DataFrame() @@ -46,7 +47,7 @@ def __init__(self) -> None: 
self.signals: dict[str, Signal] = {} self.epocs: dict[str, np.ndarray] = defaultdict(empty_array) # epocs are numpy arrays, default to empty array self.scalars: dict[str, np.ndarray] = defaultdict(empty_array) # scalars are numpy arrays, default to empty array - self.dlc: dict[str, np.ndarray] = defaultdict(empty_array) # dlc data are structured numpy arrays, default to empty array + self.dlc: dict[str, np.ndarray] = defaultdict(empty_array) # dlc data are structured numpy arrays, default to empty array def describe(self, as_str: bool = False) -> Union[str, None]: """Describe this session. @@ -188,7 +189,7 @@ def rename_epoc(self, old_name: str, new_name: str) -> None: self.epocs[new_name] = self.epocs[old_name] self.epocs.pop(old_name) - + def rename_dlc(self, old_name: str, new_name: str) -> None: """Rename a dlc array, from `old_name` to `new_name`. @@ -200,7 +201,7 @@ def rename_dlc(self, old_name: str, new_name: str) -> None: """ if new_name in self.dlc: raise KeyError(f"Key `{new_name}` already exists in data!") - + self.dlc[new_name] = self.dlc[old_name] self.dlc.pop(old_name) @@ -269,10 +270,10 @@ def scalar_dataframe(self, include_scalars: FieldList = "all", include_meta: Fie scalars.append({**meta, "scalar_name": sn, "scalar_value": self.scalars[sn]}) return pd.DataFrame(scalars) - + def dlc_dataframe(self, id: Union[str, int] = 0) -> pd.DataFrame: """ - + Args: id: identifier to select which dlc data to use in the dataframe. If str is provided, will access that named dlc data. If int is provided, will use the data from that index position among the dlc data. @@ -284,15 +285,13 @@ def dlc_dataframe(self, id: Union[str, int] = 0) -> pd.DataFrame: if isinstance(id, str): return pd.DataFrame(self.dlc[id]) - + elif isinstance(id, int): data_list = list(self.dlc.values()) return pd.DataFrame(data_list[id]) - + else: - raise TypeError( - 'Invalid `id` argument data type. Supported data identifier types are str and int.' 
- ) + raise TypeError("Invalid `id` argument data type. Supported data identifier types are str and int.") def __eq__(self, value: object) -> bool: """Test this Session for equality to another Session. @@ -358,7 +357,7 @@ def _estimate_memory_use_itemized(self) -> dict[str, int]: **{f"signal.{sig.name}": sig._estimate_memory_use() for sig in self.signals.values()}, **{f"epocs.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.epocs.items()}, **{f"scalars.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.scalars.items()}, - **{f"dlc.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.dlc.items()}, + **{f"dlc.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.dlc.items()}, } def _estimate_memory_use(self) -> int: @@ -401,7 +400,7 @@ def save(self, path: str): h5.create_group("/scalars") for k, scalar in self.scalars.items(): h5.create_dataset(f"/scalars/{k}", data=scalar, compression="gzip") - + # save dlc data h5.create_group("/dlc") for k, dlc in self.dlc.items(): @@ -473,7 +472,11 @@ def load(cls, path: str) -> "Session": for signame in h5["/signals"].keys(): sig_group = h5[f"/signals/{signame}"] sig = Signal( - signame, sig_group["signal"][()], time=sig_group["time"][()], fs=sig_group.attrs["fs"], units=sig_group.attrs["units"] + signame, + sig_group["signal"][()], + time=sig_group["time"][()], + fs=sig_group.attrs["fs"], + units=sig_group.attrs["units"], ) for mark_name in sig_group["marks"].keys(): sig.marks[mark_name] = sig_group[f"marks/{mark_name}"][()] diff --git a/fptools/io/tdt_with_dlc.py b/fptools/io/tdt_with_dlc.py index 9b66fb6..8bca850 100644 --- a/fptools/io/tdt_with_dlc.py +++ b/fptools/io/tdt_with_dlc.py @@ -14,9 +14,9 @@ def has_neighboring_dlc_h5(tbk): - dlc_id_substrs = ['DLC', 'shuffle', 'snapshot'] - neighbor_files = glob.glob(os.path.join(os.path.dirname(tbk), '*')) - neighbor_files = [file for file in neighbor_files if file.endswith('.h5')] + dlc_id_substrs = ["DLC", "shuffle", "snapshot"] + neighbor_files = 
glob.glob(os.path.join(os.path.dirname(tbk), "*")) + neighbor_files = [file for file in neighbor_files if file.endswith(".h5")] dlc_files = [file for file in neighbor_files if all(sub in file for sub in dlc_id_substrs)] @@ -32,7 +32,6 @@ def __init__(self, model_name: Optional[list[str]] = None, filtered_only: bool = self.model_name = model_name self.filtered_only = filtered_only - def __call__(self, path: str): """Data Locator for TDT blocks with DLC data. @@ -44,7 +43,7 @@ def __call__(self, path: str): Returns: list of DataTypeAdaptor, each adaptor corresponding to one session, of data to be loaded """ - + tbk_files = glob.glob(os.path.join(path, "**/*.[tT][bB][kK]"), recursive=True) # tbk_files = [file for file in tbk_files if has_neighboring_dlc_h5(file)] @@ -56,9 +55,10 @@ def __call__(self, path: str): adapt.loaders.append(TDTLoader(exclude_streams=TDT_EXCLUDE_STREAMS)) adapt.loaders.append(DLCLoader(model_name=self.model_name, filtered_only=self.filtered_only)) items_out.append(adapt) - + return items_out + # def find_tdt_w_dlc_blocks(path: str) -> list[DataTypeAdaptor]: # """Data Locator for TDT blocks with DLC data. 
@@ -70,7 +70,7 @@ def __call__(self, path: str): # Returns: # list of DataTypeAdaptor, each adaptor corresponding to one session, of data to be loaded # """ - + # tbk_files = glob.glob(os.path.join(path, "**/*.[tT][bB][kK]"), recursive=True) # tbk_files = [file for file in tbk_files if has_neighboring_dlc_h5(file)] @@ -82,16 +82,12 @@ def __call__(self, path: str): # adapt.loaders.append(TDTLoader(exclude_streams=TDT_EXCLUDE_STREAMS)) # adapt.loaders.append(DLCLoader()) # items_out.append(adapt) - + # return items_out class DLCLoader: - def __init__( - self, - model_name: Optional[list[str]] = None, - filtered_only: bool = True - ) -> None: + def __init__(self, model_name: Optional[list[str]] = None, filtered_only: bool = True) -> None: """Initialize this DLCLoader.""" self.model_name = model_name self.filtered_only = filtered_only @@ -109,25 +105,25 @@ def __call__(self, session: Session, path: str) -> Session: # find the dlc .h5 file if self.model_name is None: - pattern = os.path.join(path, f'*DLC*.h5') + pattern = os.path.join(path, f"*DLC*.h5") files = glob.glob(pattern) if len(files) <= 0: - raise FileNotFoundError(f"Could not find any DLC files in block {session.name}!") + raise FileNotFoundError(f"Could not find any DLC files in block {session.name}!") for file in files: df = pd.read_hdf(file) nparray = df.to_records(index=False) - key = Path(file).stem # TODO: instead look into the df for the model name + key = Path(file).stem # TODO: instead look into the df for the model name session.dlc[key] = nparray else: for mn in self.model_name: if self.filtered_only: - pattern = os.path.join(path, f'*{mn}*_filtered.h5') + pattern = os.path.join(path, f"*{mn}*_filtered.h5") else: - pattern = os.path.join(path, f'*{mn}*.h5') + pattern = os.path.join(path, f"*{mn}*.h5") files = glob.glob(pattern) if len(files) <= 0: - raise FileNotFoundError(f"Could not find any DLC files for model \"{mn}\" in block {session.name}!") + raise FileNotFoundError(f'Could not find any DLC 
files for model "{mn}" in block {session.name}!') for file in files: df = pd.read_hdf(file) diff --git a/fptools/viz/video.py b/fptools/viz/video.py index c0fb097..6c4819b 100644 --- a/fptools/viz/video.py +++ b/fptools/viz/video.py @@ -27,4 +27,4 @@ def get_frame_image(video: str, frame_idx: int) -> np.ndarray: return cap.release() - return frame \ No newline at end of file + return frame From ce8395211d4af953bf3e55c348475c71d5b7ab4e Mon Sep 17 00:00:00 2001 From: Riley Date: Tue, 23 Sep 2025 17:00:08 -0400 Subject: [PATCH 05/13] satisfy pydocstyle --- fptools/io/session.py | 5 ++--- fptools/io/tdt_with_dlc.py | 17 ++++++++++++++--- fptools/viz/video.py | 2 +- 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/fptools/io/session.py b/fptools/io/session.py index c82f8e8..5e7317a 100644 --- a/fptools/io/session.py +++ b/fptools/io/session.py @@ -27,7 +27,7 @@ def empty_array() -> np.ndarray: def empty_df() -> pd.DataFrame: - """Create an empty Pandas dataframe + """Create an empty Pandas dataframe. Returns: empty pd.DataFrame @@ -272,7 +272,7 @@ def scalar_dataframe(self, include_scalars: FieldList = "all", include_meta: Fie return pd.DataFrame(scalars) def dlc_dataframe(self, id: Union[str, int] = 0) -> pd.DataFrame: - """ + """Fetch DLC data as a pandas dataframe. Args: id: identifier to select which dlc data to use in the dataframe. If str is provided, will access that named dlc data. If int is provided, will use the data from that index position among the dlc data. 
@@ -282,7 +282,6 @@ def dlc_dataframe(self, id: Union[str, int] = 0) -> pd.DataFrame: Returns: DataFrame with data from this session """ - if isinstance(id, str): return pd.DataFrame(self.dlc[id]) diff --git a/fptools/io/tdt_with_dlc.py b/fptools/io/tdt_with_dlc.py index 8bca850..040980d 100644 --- a/fptools/io/tdt_with_dlc.py +++ b/fptools/io/tdt_with_dlc.py @@ -12,8 +12,15 @@ from .tdt import TDT_EXCLUDE_STREAMS, TDTLoader -def has_neighboring_dlc_h5(tbk): +def has_neighboring_dlc_h5(tbk) -> bool: + """Checks if the TBK file has a neighboring H5 file that looks like a DLC data output. + Args: + tbk (str): TBK file path to check + + Returns: + True if the TBK file has DLC neighbors, otherwise false + """ dlc_id_substrs = ["DLC", "shuffle", "snapshot"] neighbor_files = glob.glob(os.path.join(os.path.dirname(tbk), "*")) neighbor_files = [file for file in neighbor_files if file.endswith(".h5")] @@ -29,6 +36,12 @@ def has_neighboring_dlc_h5(tbk): class FindTDTDLCBlocks: def __init__(self, model_name: Optional[list[str]] = None, filtered_only: bool = True): + """Initialize this TDT-DLC Data Locator. 
+ + Args: + model_name: If provided, only look for DLC files with that model name(s), If None, load all files that look like DLC data + filtered_only: If true, only load filtered DLC data, otherwise, load any DLC data + """ self.model_name = model_name self.filtered_only = filtered_only @@ -43,7 +56,6 @@ def __call__(self, path: str): Returns: list of DataTypeAdaptor, each adaptor corresponding to one session, of data to be loaded """ - tbk_files = glob.glob(os.path.join(path, "**/*.[tT][bB][kK]"), recursive=True) # tbk_files = [file for file in tbk_files if has_neighboring_dlc_h5(file)] @@ -102,7 +114,6 @@ def __call__(self, session: Session, path: str) -> Session: Returns: Session object with data added """ - # find the dlc .h5 file if self.model_name is None: pattern = os.path.join(path, f"*DLC*.h5") diff --git a/fptools/viz/video.py b/fptools/viz/video.py index 6c4819b..e6f1ce5 100644 --- a/fptools/viz/video.py +++ b/fptools/viz/video.py @@ -4,7 +4,7 @@ def get_frame_image(video: str, frame_idx: int) -> np.ndarray: - """Get a frame from a video using openCV + """Get a frame from a video using openCV. Args: video: string path to the video file From 855429639412428fdeedf3391208f2375a65398d Mon Sep 17 00:00:00 2001 From: Riley Date: Tue, 23 Sep 2025 17:06:48 -0400 Subject: [PATCH 06/13] satisfy mypy, black --- fptools/io/tdt_with_dlc.py | 2 +- fptools/viz/video.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/fptools/io/tdt_with_dlc.py b/fptools/io/tdt_with_dlc.py index 040980d..1781e16 100644 --- a/fptools/io/tdt_with_dlc.py +++ b/fptools/io/tdt_with_dlc.py @@ -37,7 +37,7 @@ class FindTDTDLCBlocks: def __init__(self, model_name: Optional[list[str]] = None, filtered_only: bool = True): """Initialize this TDT-DLC Data Locator. 
- + Args: model_name: If provided, only look for DLC files with that model name(s), If None, load all files that look like DLC data filtered_only: If true, only load filtered DLC data, otherwise, load any DLC data diff --git a/fptools/viz/video.py b/fptools/viz/video.py index e6f1ce5..d30c8ad 100644 --- a/fptools/viz/video.py +++ b/fptools/viz/video.py @@ -1,9 +1,10 @@ import os +from typing import Union import cv2 import numpy as np -def get_frame_image(video: str, frame_idx: int) -> np.ndarray: +def get_frame_image(video: str, frame_idx: int) -> Union[np.ndarray, None]: """Get a frame from a video using openCV. Args: @@ -11,12 +12,11 @@ def get_frame_image(video: str, frame_idx: int) -> np.ndarray: frame_idx: integer index of the frame to grab Returns: - numpy array of the image at `frame_idx` + numpy array of the image at `frame_idx`, Or None if the frame could not be retrieved """ cap = cv2.VideoCapture(video) if not cap.isOpened(): raise Exception(f"Could not open video file: {video}") - return cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) @@ -24,7 +24,7 @@ def get_frame_image(video: str, frame_idx: int) -> np.ndarray: if not ret: print("Error: Could not read frame.") - return + return None cap.release() return frame From af2a085310c0d19ce2985d68476f31c2eb7a64a9 Mon Sep 17 00:00:00 2001 From: Riley Date: Wed, 24 Sep 2025 12:54:11 -0400 Subject: [PATCH 07/13] add dlc_interp step skeleton --- fptools/preprocess/steps/dlc_interp.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 fptools/preprocess/steps/dlc_interp.py diff --git a/fptools/preprocess/steps/dlc_interp.py b/fptools/preprocess/steps/dlc_interp.py new file mode 100644 index 0000000..53e4544 --- /dev/null +++ b/fptools/preprocess/steps/dlc_interp.py @@ -0,0 +1,19 @@ +from matplotlib.axes import Axes +from fptools.io.session import Session +from fptools.preprocess.common import ProcessorThatPlots + + +class DLCInterpolation(ProcessorThatPlots): + """A `Processor` that 
interpolates missing frames in dlc data.""" + + def __init__( + self, + ): + """Initialize this Processor.""" + pass + + def __call__(self, session: Session) -> Session: + pass + + def plot(self, session: Session, ax: Axes): + pass From 1060a878f2fb7f69319f357d1aad92b3c78961a7 Mon Sep 17 00:00:00 2001 From: wang-riley Date: Thu, 16 Oct 2025 13:49:34 -0400 Subject: [PATCH 08/13] added model name detection for dlcloader if model name not specified by user --- fptools/io/tdt_with_dlc.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/fptools/io/tdt_with_dlc.py b/fptools/io/tdt_with_dlc.py index 1781e16..c69d119 100644 --- a/fptools/io/tdt_with_dlc.py +++ b/fptools/io/tdt_with_dlc.py @@ -118,14 +118,23 @@ def __call__(self, session: Session, path: str) -> Session: if self.model_name is None: pattern = os.path.join(path, f"*DLC*.h5") files = glob.glob(pattern) + if len(files) <= 0: raise FileNotFoundError(f"Could not find any DLC files in block {session.name}!") for file in files: df = pd.read_hdf(file) + key = f"{df.columns[0][0]}" + df.columns = df.columns.droplevel(level=0).to_flat_index() nparray = df.to_records(index=False) - key = Path(file).stem # TODO: instead look into the df for the model name - session.dlc[key] = nparray + + if "_filtered" in file: + key += "_filtered" + session.dlc[key] = nparray + else: + if not self.filtered_only: + key += "_unfiltered" + session.dlc[key] = nparray else: for mn in self.model_name: if self.filtered_only: @@ -138,6 +147,7 @@ def __call__(self, session: Session, path: str) -> Session: for file in files: df = pd.read_hdf(file) + df.columns = df.columns.droplevel(level=0).to_flat_index() nparray = df.to_records(index=False) key = f"{mn}" if "_filtered" in file: From 65f5a8ce08dbeba52bb6b95baeaf2d63a9bd941b Mon Sep 17 00:00:00 2001 From: wang-riley Date: Tue, 28 Oct 2025 15:26:22 -0400 Subject: [PATCH 09/13] added analysis and misc to Session/SessionCollect --- fptools/io/session.py | 240 
+++++++++++++++++++++++++++++++++++++++++- 1 file changed, 239 insertions(+), 1 deletion(-) diff --git a/fptools/io/session.py b/fptools/io/session.py index 5e7317a..0c7b973 100644 --- a/fptools/io/session.py +++ b/fptools/io/session.py @@ -48,6 +48,8 @@ def __init__(self) -> None: self.epocs: dict[str, np.ndarray] = defaultdict(empty_array) # epocs are numpy arrays, default to empty array self.scalars: dict[str, np.ndarray] = defaultdict(empty_array) # scalars are numpy arrays, default to empty array self.dlc: dict[str, np.ndarray] = defaultdict(empty_array) # dlc data are structured numpy arrays, default to empty array + self.analysis: dict[str, np.ndarray] = defaultdict(empty_array) #analysis data are strictly 1d numpy arrays, default to empty array + self.misc: dict[str, Any] = {} # WARNING: careful what datatypes stored in misc, some datatypes will not play well with saving into an hdf5 def describe(self, as_str: bool = False) -> Union[str, None]: """Describe this session. @@ -113,6 +115,23 @@ def describe(self, as_str: bool = False) -> Union[str, None]: buffer += " < No DLC arrays Available >\n" buffer += "\n" + buffer += "Analysis Data:\n" + if len(self.analysis) > 0: + buffer += f"{len(self.analysis)} Analysis data arrays found: \n" + for k, v in self.analysis.items(): + buffer += f" {k}: \n" + buffer += f" array_shape = {v.shape} \n" + else: + buffer += " < No analysis data arrays available >\n" + buffer += "\n" + + buffer += "Misc:\n" + if len(self.misc) > 0: + buffer += f"{len(self.analysis)} Misc items found: \n" + for k, v in self.misc.items(): + buffer += f" {k}: \n" + buffer += f" data type = {type(v)} \n" + if as_str: return buffer else: @@ -190,6 +209,29 @@ def rename_epoc(self, old_name: str, new_name: str) -> None: self.epocs[new_name] = self.epocs[old_name] self.epocs.pop(old_name) + def add_dlc(self, dlc: Union[pd.DataFrame, np.ndarray], name: str, overwrite: bool = False) -> None: + """Add a DLC data to this Session. 
+ + Raises an error if the new DLC data name already exists and `overwrite` is not True. + + Args: + dlc: pd.DataFrame or structured numpy arrays, the DLC data to add to this Session + overwrite: if True, allow overwriting a pre-existing signal with the same name, if False, will raise instead. + """ + if name in self.dlc and not overwrite: + raise KeyError(f"Key `{name}` already exists in data!") + + if isinstance(dlc, pd.DataFrame): + dlc.columns = dlc.columns.to_flat_index() + nparray = dlc.to_records(index=False) + self.dlc[name] = nparray + + elif isinstance(dlc, np.ndarray): + self.dlc[name] = dlc + + else: + raise TypeError("Invalid `dlc` argument data type. Supported data types are pd.DataFrame and numpy arrays.") + def rename_dlc(self, old_name: str, new_name: str) -> None: """Rename a dlc array, from `old_name` to `new_name`. @@ -205,6 +247,72 @@ def rename_dlc(self, old_name: str, new_name: str) -> None: self.dlc[new_name] = self.dlc[old_name] self.dlc.pop(old_name) + def add_analysis(self, arr: np.ndarray, name: str, overwrite: bool = False) -> None: + """Add analysis data to this Session. + + Raises an error if the new analysis data name already exists and `overwrite` is not True. + + Args: + analysis: 1D numpy array of analysis data + overwrite: if True, allow overwriting a pre-existing analysis with the same name, if False, will raise error instead + """ + if name in self.analysis and not overwrite: + raise KeyError(f"Key `{name}` already exists in analysis data!") + + if isinstance(arr, np.ndarray): + if arr.ndim != 1: + raise ValueError(f"Analysis array must be 1-dimensional, but has {arr.ndim} dimensions.") + + self.analysis[name] = arr + + else: + raise TypeError("Invalid `dlc` argument data type. Supported data types are numpy arrays.") + + def rename_analysis(self, old_name: str, new_name: str) -> None: + """Rename an analysis array, from `old_name` to `new_name`. + + Raises an error if the new analysis array name already exists. 
def add_misc(self, data: Any, name: str, overwrite: bool = False) -> None:
    """Add a misc item to this Session.

    Raises an error if the misc item name already exists and `overwrite` is not True.

    Args:
        data: the misc item to store. WARNING: careful what datatypes are stored in
            misc — some datatypes will not save cleanly into an HDF5 file.
        name: key under which to store the item
        overwrite: if True, allow overwriting a pre-existing misc item with the same
            name; if False, raise KeyError instead
    """
    if name in self.misc and not overwrite:
        raise KeyError(f"Key `{name}` already exists in misc items!")

    # no `else` needed: the guard above either raises or falls through
    self.misc[name] = data


def rename_misc(self, old_name: str, new_name: str) -> None:
    """Rename a misc item, from `old_name` to `new_name`.

    Raises an error if the new misc item name already exists.

    Args:
        old_name: the current name for the misc item
        new_name: the new name for the misc item
    """
    if new_name in self.misc:
        # closing backtick was missing from this message in the original
        raise KeyError(f"Key `{new_name}` already exists in misc items!")

    self.misc[new_name] = self.misc.pop(old_name)
Special str "all" is also accepted. + + Returns: + DataFrame with data from this session + """ + # determine metadata fields to include + if include_meta == "all": + meta = self.metadata + else: + meta = {k: v for k, v in self.metadata.items() if k in include_meta} + + # determine arrays to include + if include_analysis == "all": + analysis_names = list(self.analysis.keys()) + else: + analysis_names = [k for k in self.analysis.keys() if k in include_analysis] + + # TODO: iterate arrays and include any the user requested + # also add in any requested metadata + data = [] + for k, v in self.analysis.items(): + if k in analysis_names: + if len(v) == 1: + for value in v: + data.append({**meta, "metric": k, "obs": np.nan, "value": value}) + else: + obsn = 1 + for value in v: + data.append({**meta, "metric": k, "obs": obsn,"value": value}) + obsn += 1 + + df = pd.DataFrame(data) + + return df def __eq__(self, value: object) -> bool: """Test this Session for equality to another Session. @@ -357,6 +505,8 @@ def _estimate_memory_use_itemized(self) -> dict[str, int]: **{f"epocs.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.epocs.items()}, **{f"scalars.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.scalars.items()}, **{f"dlc.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.dlc.items()}, + **{f"analysis.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.analysis.items()}, + **{f"misc.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.misc.items()}, } def _estimate_memory_use(self) -> int: @@ -405,6 +555,16 @@ def save(self, path: str): for k, dlc in self.dlc.items(): h5.create_dataset(f"/dlc/{k}", data=dlc) + # save analysis data + h5.create_group("/analysis") + for k, analysis in self.analysis.items(): + h5.create_dataset(f"/analysis/{k}", data=analysis) + + # save misc data + h5.create_group("/misc") + for k, misc in self.misc.items(): + h5.create_dataset(f"/misc/{k}", data=misc) + # save metadata meta_group = h5.create_group("/metadata") for k, v in 
self.metadata.items(): @@ -496,6 +656,16 @@ def load(cls, path: str) -> "Session": for dlc_name in h5["/dlc"].keys(): session.dlc[dlc_name] = h5[f"/dlc/{dlc_name}"][()] + # read analysis + if "/analysis" in h5: + for analysis_name in h5["/analysis"].keys(): + session.analysis[analysis_name] = h5[f"/analysis/{analysis_name}"][()] + + # read misc + if "/misc" in h5: + for misc_name in h5["/misc"].keys(): + session.misc[misc_name] = h5[f"/misc/{misc_name}"][()] + # read metadata if "/metadata" in h5: for meta_name in h5[f"/metadata"].keys(): @@ -610,6 +780,26 @@ def rename_scalar(self, old_name: str, new_name: str) -> None: for item in self: item.rename_scalar(old_name, new_name) + def rename_analysis(self, old_name: str, new_name: str) -> None: + """Rename an analysis on each session in this collection. + + Args: + old_name: current name of the analysis + new_name: the new name for the analysis + """ + for item in self: + item.rename_analysis(old_name, new_name) + + def rename_misc(self, old_name: str, new_name: str) -> None: + """Rename a misc item on each session in this collection. + + Args: + old_name: current name of the misc item + new_name: the new name for the misc item + """ + for item in self: + item.rename_misc(old_name, new_name) + def filter(self, predicate: Callable[[Session], bool]) -> "SessionCollection": """Filter the items in this collection, returning a new `SessionCollection` containing sessions which pass `predicate`. @@ -781,6 +971,42 @@ def signal_dataframe(self, signal: str, include_meta: FieldList = "all") -> pd.D dfs.append(df) return pd.concat(dfs, ignore_index=True) + + def analysis_dataframe(self, include_analysis: FieldList = "all", include_meta: FieldList = "all") -> pd.DataFrame: + """Produce a dataframe with analysis data and metadata across all sessions in this collection. + + Args: + include_analysis: list of analysis names to include in the dataframe. 
Special str "all" is also accepted + include_meta: list of metadata fields to include in the dataframe. Special str "all" is also accepted + + Returns: + DataFrame with analysis data from across this collection + """ + dfs = [session.analysis_dataframe(include_analysis=include_analysis, include_meta=include_meta) for session in self] + return pd.concat(dfs).reset_index(drop=True) + + def add_analysis(self, name: str, analysis: Callable[[Session], np.ndarray]) -> None: + """Apply an analysis function to each session in this collection, adding the results to each session's analysis attribute. + + Args: + name: Name of the new analysis data key + analysis: callable accepting a single session and returning a 1d numpy array + """ + for session in self: + session.add_analysis(analysis(session), name) + + def map(self, action: Callable[[Session], Session]) -> "SessionCollection": + """Apply a function to each session in this collection, returning a new collection with the results. + + Args: + action: callable accepting a single session and returning a new session + + Returns: + a new `SessionCollection` containing the results of `action` + """ + sc = type(self)(action(item) for item in self) + sc.__meta_meta.update(**copy.deepcopy(self.__meta_meta)) + return sc def aggregate_signals(self, name: str, method: Union[None, str, np.ufunc, Callable[[np.ndarray], np.ndarray]] = "median") -> Signal: """Aggregate signals across sessions in this collection for the signal name `name`. 
@@ -791,7 +1017,7 @@ def aggregate_signals(self, name: str, method: Union[None, str, np.ufunc, Callab Returns: Aggregated `Signal` - """ + """ signals = [s for s in self.get_signal(name) if s.nobs > 0] if len(signals) <= 0: raise ValueError("No signals were passed!") @@ -843,6 +1069,18 @@ def describe(self, as_str: bool = False) -> Union[str, None]: buffer += f'({v}) "{k}"\n' buffer += "\n" + analysis = Counter([item for session in self for item in session.analysis.keys()]) + buffer += "Analysis arrays present in data:\n" + for k, v in analysis.items(): + buffer += f'({v}) "{k}"\n' + buffer += "\n" + + misc = Counter([item for session in self for item in session.misc.keys()]) + buffer += "Misc items present in data:\n" + for k, v in misc.items(): + buffer += f'({v}) "{k}"\n' + buffer += "\n" + if as_str: return buffer else: From 0e22c1aab59de52e7671d1afde898309441af050 Mon Sep 17 00:00:00 2001 From: wang-riley Date: Thu, 6 Nov 2025 10:40:07 -0500 Subject: [PATCH 10/13] correct misc item count display and improve dlc/misc data handling in HDF5 export --- fptools/io/session.py | 59 +++++++++++++++++++++++++++++++++---------- 1 file changed, 46 insertions(+), 13 deletions(-) diff --git a/fptools/io/session.py b/fptools/io/session.py index 0c7b973..f9290bd 100644 --- a/fptools/io/session.py +++ b/fptools/io/session.py @@ -127,7 +127,7 @@ def describe(self, as_str: bool = False) -> Union[str, None]: buffer += "Misc:\n" if len(self.misc) > 0: - buffer += f"{len(self.analysis)} Misc items found: \n" + buffer += f"{len(self.misc)} Misc items found: \n" for k, v in self.misc.items(): buffer += f" {k}: \n" buffer += f" data type = {type(v)} \n" @@ -221,16 +221,21 @@ def add_dlc(self, dlc: Union[pd.DataFrame, np.ndarray], name: str, overwrite: bo if name in self.dlc and not overwrite: raise KeyError(f"Key `{name}` already exists in data!") - if isinstance(dlc, pd.DataFrame): - dlc.columns = dlc.columns.to_flat_index() - nparray = dlc.to_records(index=False) - self.dlc[name] = 
nparray + # if isinstance(dlc, pd.DataFrame): + # dlc.columns = dlc.columns.to_flat_index() + + # string_cols = dlc.select_dtypes(include=['str']).columns + # for col in string_cols: + # dlc[col] = dlc[col].astype(np.bytes_).astype('S50') + + # nparray = dlc.to_records(index=False) + # self.dlc[name] = nparray - elif isinstance(dlc, np.ndarray): - self.dlc[name] = dlc + # elif isinstance(dlc, np.ndarray): + self.dlc[name] = dlc - else: - raise TypeError("Invalid `dlc` argument data type. Supported data types are pd.DataFrame and numpy arrays.") + # else: + # raise TypeError("Invalid `dlc` argument data type. Supported data types are pd.DataFrame and numpy arrays.") def rename_dlc(self, old_name: str, new_name: str) -> None: """Rename a dlc array, from `old_name` to `new_name`. @@ -266,7 +271,7 @@ def add_analysis(self, arr: np.ndarray, name: str, overwrite: bool = False) -> N self.analysis[name] = arr else: - raise TypeError("Invalid `dlc` argument data type. Supported data types are numpy arrays.") + raise TypeError("Invalid `arr` argument data type. Supported data types are numpy arrays.") def rename_analysis(self, old_name: str, new_name: str) -> None: """Rename an analysis array, from `old_name` to `new_name`. @@ -381,7 +386,6 @@ def scalar_dataframe(self, include_scalars: FieldList = "all", include_meta: Fie def dlc_dataframe(self, id: Union[str, int] = 0) -> pd.DataFrame: """Fetch DLC data as a pandas dataframe. - Args: id: identifier to select which dlc data to use in the dataframe. If str is provided, will access that named dlc data. If int is provided, will use the data from that index position among the dlc data. 
@@ -519,6 +523,7 @@ def save(self, path: str): Args: path: path where the data should be saved """ + with h5py.File(path, mode="w") as h5: # save name h5.create_dataset("/name", data=self.name) @@ -553,7 +558,21 @@ def save(self, path: str): # save dlc data h5.create_group("/dlc") for k, dlc in self.dlc.items(): - h5.create_dataset(f"/dlc/{k}", data=dlc) + if isinstance(dlc, pd.DataFrame): + dlc.columns = dlc.columns.to_flat_index() + + for col in dlc.select_dtypes(include='object').columns: + dlc[col] = dlc[col].apply(lambda x: ','.join(map(str, x)) if isinstance(x, (list, tuple)) else x) + + for col in dlc.select_dtypes(include='object').columns: + dlc[col] = dlc[col].astype(np.bytes_).astype('S50') + # string_cols = [col for col in dlc.columns if isinstance(dlc[col].iloc[2], str)] + # for col in string_cols: + # dlc[col] = dlc[col].astype(np.bytes_).astype('S50') + nparray = dlc.to_records(index=False) + h5.create_dataset(f"/dlc/{k}", data=nparray) + elif isinstance(dlc, np.ndarray): + h5.create_dataset(f"/dlc/{k}", data=dlc) # save analysis data h5.create_group("/analysis") @@ -563,7 +582,21 @@ def save(self, path: str): # save misc data h5.create_group("/misc") for k, misc in self.misc.items(): - h5.create_dataset(f"/misc/{k}", data=misc) + if isinstance(misc, pd.DataFrame): + misc.columns = misc.columns.to_flat_index() + + for col in dlc.select_dtypes(include='object').columns: + dlc[col] = dlc[col].apply(lambda x: ','.join(map(str, x)) if isinstance(x, (list, tuple)) else x) + + for col in dlc.select_dtypes(include='object').columns: + dlc[col] = dlc[col].astype(np.bytes_).astype('S50') + # string_cols = [col for col in dlc.columns if isinstance(dlc[col].iloc[2], str)] + # for col in string_cols: + # misc[col] = misc[col].astype(np.bytes_).astype('S50') + nparray = misc.to_records(index=False) + h5.create_dataset(f"/misc/{k}", data=nparray) + else: + h5.create_dataset(f"/misc/{k}", data=misc) # save metadata meta_group = h5.create_group("/metadata") From 
def add_epoc(self, arr: np.ndarray, name: str, overwrite: bool = False) -> None:
    """Add epoc data to this Session.

    Raises an error if the new epoc name already exists and `overwrite` is not True.

    Args:
        arr: 1D numpy array of epoc timestamps
        name: key under which to store the epoc
        overwrite: if True, allow overwriting a pre-existing epoc with the same name;
            if False, raise KeyError instead
    """
    if name in self.epocs and not overwrite:
        # message previously said "analysis data"; corrected to refer to epocs
        raise KeyError(f"Key `{name}` already exists in epocs!")

    if not isinstance(arr, np.ndarray):
        raise TypeError("Invalid `arr` argument data type. Supported data types are numpy arrays.")

    if arr.ndim != 1:
        raise ValueError(f"Epoc data must be 1-dimensional, but has {arr.ndim} dimensions.")

    self.epocs[name] = arr
+ + Args: + name: Name of the new analysis data key + epoc_func: callable accepting a single session and returning a 1d numpy array + """ + for session in self: + session.add_epoc(epoc_func(session), name) def epoc_dataframe(self, include_epocs: FieldList = "all", include_meta: FieldList = "all") -> pd.DataFrame: """Produce a dataframe with epoc data and metadata across all the sessions in this collection. From 9c433c4dc2af1f188be70c608616c7ce3376c38e Mon Sep 17 00:00:00 2001 From: wang-riley Date: Wed, 12 Nov 2025 09:41:38 -0500 Subject: [PATCH 12/13] added decoding for stored dataframes --- fptools/io/session.py | 39 +++++++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/fptools/io/session.py b/fptools/io/session.py index 8a22cff..e19b496 100644 --- a/fptools/io/session.py +++ b/fptools/io/session.py @@ -6,6 +6,7 @@ import os import sys from typing import Any, Callable, Literal, Optional, Union +import ast import h5py import numpy as np @@ -35,6 +36,14 @@ def empty_df() -> pd.DataFrame: """ return pd.DataFrame() +def decode_byteseq(x): + """Decode an object encoded by utf-8. + """ + try: + return ast.literal_eval(x.decode("utf-8")) + except ValueError: + return x.decode("utf-8") + class Session(object): """Holds data and metadata for a single session.""" @@ -416,15 +425,17 @@ def dlc_dataframe(self, id: Union[str, int] = 0) -> pd.DataFrame: DataFrame with data from this session """ if isinstance(id, str): - return pd.DataFrame(self.dlc[id]) - + df = pd.DataFrame(self.dlc[id]) elif isinstance(id, int): data_list = list(self.dlc.values()) - return pd.DataFrame(data_list[id]) - + df = pd.DataFrame(data_list[id]) else: raise TypeError("Invalid `id` argument data type. 
Supported data identifier types are str and int.") + df[df.select_dtypes(include='object').columns] = df.select_dtypes(include='object').map(decode_byteseq) + + return df + def analysis_dataframe(self, include_analysis: FieldList = "all", include_meta: FieldList = "all") -> pd.DataFrame: """Produce a dataframe with analysis data and metadata. @@ -447,8 +458,6 @@ def analysis_dataframe(self, include_analysis: FieldList = "all", include_meta: else: analysis_names = [k for k in self.analysis.keys() if k in include_analysis] - # TODO: iterate arrays and include any the user requested - # also add in any requested metadata data = [] for k, v in self.analysis.items(): if k in analysis_names: @@ -464,6 +473,24 @@ def analysis_dataframe(self, include_analysis: FieldList = "all", include_meta: df = pd.DataFrame(data) return df + + def misc_dataframe(self, id: str) -> pd.DataFrame: + """Fetch misc data as a pandas dataframe. + Args: + id: name of the misc item to load + + Returns: + DataFrame with data from this session + """ + if isinstance(self.misc[id], np.array): + df = pd.DataFrame(self.misc[id]) + + else: + raise TypeError("Only numpy arrays can be loaded as a pandas DataFrame.") + + df[df.select_dtypes(include='object').columns] = df.select_dtypes(include='object').map(decode_byteseq) + + return df def __eq__(self, value: object) -> bool: """Test this Session for equality to another Session. 
def decode_byteseq(x):
    """Decode a utf-8 encoded bytes value back to a Python value.

    Bytes that parse as a Python literal (e.g. b"(1, 2)") are rebuilt with
    `ast.literal_eval`; other bytes are returned as the decoded str. Non-bytes
    values are returned unchanged (the original implementation fell off the
    end and returned None for them, silently nulling non-bytes cells).

    Args:
        x: value to decode, typically a cell from a DataFrame object column

    Returns:
        the parsed literal, the decoded str, or `x` unchanged if it is not bytes
    """
    if not isinstance(x, bytes):
        return x
    decoded = x.decode("utf-8")
    try:
        # literal_eval raises ValueError *or* SyntaxError for non-literal text;
        # the original only caught ValueError and crashed on e.g. b"a b"
        return ast.literal_eval(decoded)
    except (ValueError, SyntaxError):
        return decoded


def remove_non_letters(text: str) -> str:
    """Remove all non-letter characters from a string.

    Args:
        text: str to process

    Returns:
        str containing only the letters a-z / A-Z from `text`
    """
    # [^a-zA-Z] matches any character outside a-z / A-Z; replace with ''
    return re.sub(r'[^a-zA-Z]', '', text)
isinstance(self.misc[id], np.ndarray): df = pd.DataFrame(self.misc[id]) else: @@ -625,13 +646,15 @@ def save(self, path: str): # save analysis data h5.create_group("/analysis") for k, analysis in self.analysis.items(): - h5.create_dataset(f"/analysis/{k}", data=analysis) - + h5.create_dataset(f"/analysis/{k}", data=np.atleast_1d(analysis).astype(float), compression="gzip") + # h5.create_dataset(f"/analysis/{k}", data=np.atleast_1d(analysis), dtype="f8", compression="gzip") + # save misc data h5.create_group("/misc") for k, misc in self.misc.items(): if isinstance(misc, pd.DataFrame): misc.columns = misc.columns.to_flat_index() + misc.rename(mapper=remove_non_letters, axis=1) for col in dlc.select_dtypes(include='object').columns: dlc[col] = dlc[col].apply(lambda x: ','.join(map(str, x)) if isinstance(x, (list, tuple)) else x) @@ -642,9 +665,12 @@ def save(self, path: str): # for col in string_cols: # misc[col] = misc[col].astype(np.bytes_).astype('S50') nparray = misc.to_records(index=False) - h5.create_dataset(f"/misc/{k}", data=nparray) - else: + h5.create_dataset(f"/misc/{k}", data=nparray) + elif misc.dtype.names is not None: h5.create_dataset(f"/misc/{k}", data=misc) + else: + h5.create_dataset(f"/misc/{k}", data=np.atleast_1d(misc).astype(float), compression="gzip") + # h5.create_dataset(f"/misc/{k}", data=np.atleast_1d(misc), dtype="f8", compression="gzip") # save metadata meta_group = h5.create_group("/metadata") @@ -991,7 +1017,7 @@ def get_signal(self, name: str) -> list[Signal]: """ return [item.signals[name] for item in self if name in item.signals] - def add_analysis(self, name: str, epoc_func: Callable[[Session], np.ndarray]) -> None: + def add_epoc(self, name: str, epoc_func: Callable[[Session], np.ndarray]) -> None: """Apply an epoc function to each session in this collection, adding the results to each session's epocs attribute. 
def run_analysis(self, analysis: Callable[[Session], Session], max_workers: Optional[int] = None) -> "SessionCollection":
    """Run an analysis function on each session in this collection in parallel.

    Sessions are dispatched to a spawn-context process pool; each worker
    handles one task and is then replaced (max_tasks_per_child=1), so state
    cannot leak between sessions. A session whose analysis raises is reported
    via tqdm.write — keyed by the session's "blockname" metadata — and omitted
    from the result.

    Note: results are appended in completion order, which may differ from this
    collection's order.

    Args:
        analysis: callable accepting a single Session and returning a new Session
        max_workers: number of workers in the process pool for running analysis.
            If None, defaults to the number of CPUs on the machine.

    Returns:
        a new `SessionCollection` containing the results of `analysis`
    """
    futures: dict[Future[Session], str] = {}
    context = multiprocessing.get_context("spawn")
    results = SessionCollection()

    with ProcessPoolExecutor(max_workers=max_workers, mp_context=context, max_tasks_per_child=1) as executor:
        # submit one task per session, remembering its blockname for error reporting
        for session in self:
            future = executor.submit(analysis, session)
            futures[future] = session.metadata["blockname"]

        # collect results as they complete, dropping sessions whose analysis failed
        for future in tqdm(as_completed(futures), total=len(futures)):
            try:
                results.append(future.result())
            except Exception:
                tqdm.write(
                    f'Problem running analysis at "{futures[future]}":\n{traceback.format_exc()}\nThe session will be missing from the resultant SessionCollection!\n'
                )

    return results
+ + Args: + analysis: callable accepting a single session with optional additional keyword arguments and returning a new session + max_workers: number of workers in the process pool for running analysis. If None, defaults to the number of CPUs on the machine. + ana_kwargs: kwargs to pass to analysis callable + + Returns: + a new `SessionCollection` containing results of `analysis` + + """ + futures: dict[Future[Session], str] = {} + context = multiprocessing.get_context("spawn") + max_tasks_per_child = 1 + + _analysis_kwargs = {} + if analysis_kwargs is not None: + _analysis_kwargs.update(analysis_kwargs) + + sc = SessionCollection() + + with ProcessPoolExecutor(max_workers=max_workers, mp_context=context, max_tasks_per_child=max_tasks_per_child) as executor: + + # iterate over all Sessions in the SessionCollection + + for s in self: + + # submit the task to the pool + f = executor.submit(analysis, s, **_analysis_kwargs) + futures[f] = s.metadata["blockname"] + + # compile the new SessionCollection + for f in tqdm(as_completed(futures), total=len(futures)): + try: + sc.append(f.result()) + except Exception as e: + tqdm.write( + f'Problem running analysis at "{futures[f]}":\n{traceback.format_exc()}\nThe session will be missing from the resultant SessionCollection!\n' + ) + pass + + return sc + def map(self, action: Callable[[Session], Session]) -> "SessionCollection": """Apply a function to each session in this collection, returning a new collection with the results.