diff --git a/fptools/io/data_loader.py b/fptools/io/data_loader.py
index 8331103..ce5d27c 100644
--- a/fptools/io/data_loader.py
+++ b/fptools/io/data_loader.py
@@ -11,6 +11,7 @@
 from .common import DataLocator, DataTypeAdaptor
 from .med_associates import find_ma_blocks
 from .tdt import find_tdt_blocks
+from .tdt_with_dlc import FindTDTDLCBlocks
 from .session import Session, SessionCollection
 from tqdm.auto import tqdm
@@ -57,7 +58,7 @@ def load_data(
     manifest_path: Optional[str] = None,
     manifest_index: str = "blockname",
     max_workers: Optional[int] = None,
-    locator: Union[Literal["auto", "tdt", "ma"], DataLocator] = "auto",
+    locator: Union[Literal["auto", "tdt", "tdt_dlc", "ma"], DataLocator] = "auto",
     preprocess: Optional[Processor] = None,
     cache: bool = True,
     cache_dir: str = "cache",
@@ -214,7 +215,7 @@ def _load(
     return session
 
 
-def _get_locator(locator: Union[Literal["auto", "tdt", "ma"], DataLocator] = "auto") -> DataLocator:
+def _get_locator(locator: Union[Literal["auto", "tdt", "tdt_dlc", "ma"], DataLocator] = "auto") -> DataLocator:
     """Translate a flexible locator argument to a concrete DataLocator implementation.
 
     Args:
@@ -227,6 +228,8 @@ def _get_locator(locator: Union[Literal["auto", "tdt", "ma"], DataLocator] = "au
         return _find_any_data
     elif locator == "tdt":
         return find_tdt_blocks
+    elif locator == "tdt_dlc":
+        return FindTDTDLCBlocks()
     elif locator == "ma":
         return find_ma_blocks
     else:
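[Note] With this change, the DLC-aware locator can be selected by name. A minimal usage sketch (the data directory is hypothetical, and the first positional argument is assumed to be the search path, as with the existing locators):

    from fptools.io.data_loader import load_data

    # "tdt_dlc" resolves to a default FindTDTDLCBlocks() via _get_locator;
    # pass a configured FindTDTDLCBlocks instance instead to control which
    # DLC model outputs are loaded (see fptools/io/tdt_with_dlc.py below).
    sessions = load_data("/path/to/tdt_blocks", locator="tdt_dlc")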
+ """ + if isinstance(x, bytes): + try: + return ast.literal_eval(x.decode("utf-8")) + except ValueError: + if isinstance(x, bytes): + return x.decode("utf-8") + else: + return x + +def remove_non_letters(text: str) -> str: + """Remove all non-letter characters from a string + + Args: + text: str to process + + Returns: + str containing only non-letter characters""" + # [^a-zA-Z] matches any character not in the range a-z or A-Z + # It replaces the matched characters with an empty string '' + return re.sub(r'[^a-zA-Z]', '', text) + + class Session(object): """Holds data and metadata for a single session.""" @@ -37,6 +77,9 @@ def __init__(self) -> None: self.signals: dict[str, Signal] = {} self.epocs: dict[str, np.ndarray] = defaultdict(empty_array) # epocs are numpy arrays, default to empty array self.scalars: dict[str, np.ndarray] = defaultdict(empty_array) # scalars are numpy arrays, default to empty array + self.dlc: dict[str, np.ndarray] = defaultdict(empty_array) # dlc data are structured numpy arrays, default to empty array + self.analysis: dict[str, np.ndarray] = defaultdict(empty_array) #analysis data are strictly 1d numpy arrays, default to empty array + self.misc: dict[str, Any] = {} # WARNING: careful what datatypes stored in misc, some datatypes will not play well with saving into an hdf5 def describe(self, as_str: bool = False) -> Union[str, None]: """Describe this session. @@ -92,11 +135,59 @@ def describe(self, as_str: bool = False) -> Union[str, None]: buffer += " < No Signals Available >\n" buffer += "\n" + buffer += "DLC Data:\n" + if len(self.dlc) > 0: + buffer += f"{len(self.dlc)} DLC arrays found: \n" + for k, v in self.dlc.items(): + buffer += f" {k}: \n" + buffer += f" array_shape = {v.shape} \n" + else: + buffer += " < No DLC arrays Available >\n" + buffer += "\n" + + buffer += "Analysis Data:\n" + if len(self.analysis) > 0: + buffer += f"{len(self.analysis)} Analysis data arrays found: \n" + for k, v in self.analysis.items(): + buffer += f" {k}: \n" + buffer += f" array_shape = {v.shape} \n" + else: + buffer += " < No analysis data arrays available >\n" + buffer += "\n" + + buffer += "Misc:\n" + if len(self.misc) > 0: + buffer += f"{len(self.misc)} Misc items found: \n" + for k, v in self.misc.items(): + buffer += f" {k}: \n" + buffer += f" data type = {type(v)} \n" + if as_str: return buffer else: print(buffer) return None + + def add_epoc(self, arr: np.ndarray, name: str, overwrite: bool = False) -> None: + """Add epoc data to this Session. + + Raises an error if the new epoc name already exists and `overwrite` is not True. + + Args: + epoc: 1D numpy array of epoc timestamps + overwrite: if True, allow overwriting a pre-existing epoc with the same name, if False, will raise error instead + """ + if name in self.epocs and not overwrite: + raise KeyError(f"Key `{name}` already exists in analysis data!") + + if isinstance(arr, np.ndarray): + if arr.ndim != 1: + raise ValueError(f"Epoc data must be 1-dimensional, but has {arr.ndim} dimensions.") + + self.epocs[name] = arr + + else: + raise TypeError("Invalid `arr` argument data type. Supported data types are numpy arrays.") def add_signal(self, signal: Signal, overwrite: bool = False) -> None: """Add a signal to this Session. @@ -169,6 +260,115 @@ def rename_epoc(self, old_name: str, new_name: str) -> None: self.epocs[new_name] = self.epocs[old_name] self.epocs.pop(old_name) + def add_dlc(self, dlc: Union[pd.DataFrame, np.ndarray], name: str, overwrite: bool = False) -> None: + """Add a DLC data to this Session. 
 class Session(object):
     """Holds data and metadata for a single session."""
 
@@ -37,6 +77,9 @@ def __init__(self) -> None:
         self.signals: dict[str, Signal] = {}
         self.epocs: dict[str, np.ndarray] = defaultdict(empty_array)  # epocs are numpy arrays, default to empty array
         self.scalars: dict[str, np.ndarray] = defaultdict(empty_array)  # scalars are numpy arrays, default to empty array
+        self.dlc: dict[str, np.ndarray] = defaultdict(empty_array)  # dlc data are structured numpy arrays, default to empty array
+        self.analysis: dict[str, np.ndarray] = defaultdict(empty_array)  # analysis data are strictly 1d numpy arrays, default to empty array
+        self.misc: dict[str, Any] = {}  # WARNING: be careful which data types are stored in misc; some will not save cleanly into an HDF5 file
 
     def describe(self, as_str: bool = False) -> Union[str, None]:
         """Describe this session.
@@ -92,11 +135,59 @@
             buffer += " < No Signals Available >\n"
         buffer += "\n"
 
+        buffer += "DLC Data:\n"
+        if len(self.dlc) > 0:
+            buffer += f"{len(self.dlc)} DLC arrays found: \n"
+            for k, v in self.dlc.items():
+                buffer += f"    {k}: \n"
+                buffer += f"        array_shape = {v.shape} \n"
+        else:
+            buffer += " < No DLC arrays Available >\n"
+        buffer += "\n"
+
+        buffer += "Analysis Data:\n"
+        if len(self.analysis) > 0:
+            buffer += f"{len(self.analysis)} Analysis data arrays found: \n"
+            for k, v in self.analysis.items():
+                buffer += f"    {k}: \n"
+                buffer += f"        array_shape = {v.shape} \n"
+        else:
+            buffer += " < No analysis data arrays available >\n"
+        buffer += "\n"
+
+        buffer += "Misc:\n"
+        if len(self.misc) > 0:
+            buffer += f"{len(self.misc)} Misc items found: \n"
+            for k, v in self.misc.items():
+                buffer += f"    {k}: \n"
+                buffer += f"        data type = {type(v)} \n"
+        else:
+            buffer += " < No Misc items available >\n"
+        buffer += "\n"
+
         if as_str:
             return buffer
         else:
             print(buffer)
             return None
 
+    def add_epoc(self, arr: np.ndarray, name: str, overwrite: bool = False) -> None:
+        """Add epoc data to this Session.
+
+        Raises an error if the new epoc name already exists and `overwrite` is not True.
+
+        Args:
+            arr: 1D numpy array of epoc timestamps
+            name: name for the new epoc
+            overwrite: if True, allow overwriting a pre-existing epoc with the same name; if False, raise an error instead
+        """
+        if name in self.epocs and not overwrite:
+            raise KeyError(f"Key `{name}` already exists in epocs!")
+
+        if isinstance(arr, np.ndarray):
+            if arr.ndim != 1:
+                raise ValueError(f"Epoc data must be 1-dimensional, but has {arr.ndim} dimensions.")
+
+            self.epocs[name] = arr
+        else:
+            raise TypeError("Invalid `arr` argument data type. Supported data types are numpy arrays.")
 
     def add_signal(self, signal: Signal, overwrite: bool = False) -> None:
         """Add a signal to this Session.
@@ -169,6 +260,115 @@ def rename_epoc(self, old_name: str, new_name: str) -> None:
         self.epocs[new_name] = self.epocs[old_name]
         self.epocs.pop(old_name)
 
+    def add_dlc(self, dlc: Union[pd.DataFrame, np.ndarray], name: str, overwrite: bool = False) -> None:
+        """Add DLC data to this Session.
+
+        Raises an error if the new DLC data name already exists and `overwrite` is not True.
+
+        Args:
+            dlc: pd.DataFrame or structured numpy array, the DLC data to add to this Session
+            name: name for the new DLC data
+            overwrite: if True, allow overwriting pre-existing DLC data with the same name; if False, raise an error instead
+        """
+        if name in self.dlc and not overwrite:
+            raise KeyError(f"Key `{name}` already exists in data!")
+
+        self.dlc[name] = dlc
+
+    def rename_dlc(self, old_name: str, new_name: str) -> None:
+        """Rename a dlc array, from `old_name` to `new_name`.
+
+        Raises an error if the new dlc array name already exists.
+
+        Args:
+            old_name: the current name for the dlc array
+            new_name: the new name for the dlc array
+        """
+        if new_name in self.dlc:
+            raise KeyError(f"Key `{new_name}` already exists in data!")
+
+        self.dlc[new_name] = self.dlc[old_name]
+        self.dlc.pop(old_name)
+
+    def add_analysis(self, arr: np.ndarray, name: str, overwrite: bool = False) -> None:
+        """Add analysis data to this Session.
+
+        Raises an error if the new analysis data name already exists and `overwrite` is not True.
+
+        Args:
+            arr: 1D numpy array of analysis data
+            name: name for the new analysis data
+            overwrite: if True, allow overwriting pre-existing analysis data with the same name; if False, raise an error instead
+        """
+        if name in self.analysis and not overwrite:
+            raise KeyError(f"Key `{name}` already exists in analysis data!")
+
+        if isinstance(arr, np.ndarray):
+            if arr.ndim != 1:
+                raise ValueError(f"Analysis array must be 1-dimensional, but has {arr.ndim} dimensions.")
+
+            self.analysis[name] = arr
+        else:
+            raise TypeError("Invalid `arr` argument data type. Supported data types are numpy arrays.")
+
+    def rename_analysis(self, old_name: str, new_name: str) -> None:
+        """Rename an analysis array, from `old_name` to `new_name`.
+
+        Raises an error if the new analysis array name already exists.
+
+        Args:
+            old_name: the current name for the analysis array
+            new_name: the new name for the analysis array
+        """
+        if new_name in self.analysis:
+            raise KeyError(f"Key `{new_name}` already exists in analysis data!")
+
+        self.analysis[new_name] = self.analysis[old_name]
+        self.analysis.pop(old_name)
+
+    def add_misc(self, data: Any, name: str, overwrite: bool = False) -> None:
+        """Add a misc item to this Session.
+
+        Raises an error if the new misc item already exists and `overwrite` is not True.
+
+        Args:
+            data: misc item
+            name: name for the new misc item
+            overwrite: if True, allow overwriting a pre-existing misc item with the same name; if False, raise an error instead
+        """
+        if name in self.misc and not overwrite:
+            raise KeyError(f"Key `{name}` already exists in misc items!")
+
+        self.misc[name] = data
+
+    def rename_misc(self, old_name: str, new_name: str) -> None:
+        """Rename a misc item, from `old_name` to `new_name`.
+
+        Raises an error if the new misc item name already exists.
+
+        Args:
+            old_name: the current name for the misc item
+            new_name: the new name for the misc item
+        """
+        if new_name in self.misc:
+            raise KeyError(f"Key `{new_name}` already exists in misc items!")
+
+        self.misc[new_name] = self.misc[old_name]
+        self.misc.pop(old_name)
 
     def epoc_dataframe(self, include_epocs: FieldList = "all", include_meta: FieldList = "all") -> pd.DataFrame:
         """Produce a dataframe with epoc data and metadata.
@@ -235,6 +435,84 @@ def scalar_dataframe(self, include_scalars: FieldList = "all", include_meta: Fie
 
         return pd.DataFrame(scalars)
 
+    def dlc_dataframe(self, id: Union[str, int] = 0) -> pd.DataFrame:
+        """Fetch DLC data as a pandas dataframe.
+
+        Args:
+            id: identifier selecting which DLC data to use. If a str is provided, the DLC data with that name is accessed; if an int is provided, the data at that index position among the DLC data is used. By default index 0 is accessed.
+
+        Returns:
+            DataFrame with data from this session
+        """
+        if isinstance(id, str):
+            df = pd.DataFrame(self.dlc[id])
+        elif isinstance(id, int):
+            data_list = list(self.dlc.values())
+            df = pd.DataFrame(data_list[id])
+        else:
+            raise TypeError("Invalid `id` argument data type. Supported data identifier types are str and int.")
+
+        df[df.select_dtypes(include='object').columns] = df.select_dtypes(include='object').map(decode_byteseq)
+
+        return df
+
+    def analysis_dataframe(self, include_analysis: FieldList = "all", include_meta: FieldList = "all") -> pd.DataFrame:
+        """Produce a dataframe with analysis data and metadata.
+
+        Args:
+            include_analysis: list of analysis array names to include in the dataframe. Special str "all" is also accepted.
+            include_meta: list of metadata fields to include in the dataframe. Special str "all" is also accepted.
+
+        Returns:
+            DataFrame with data from this session
+        """
+        # determine metadata fields to include
+        if include_meta == "all":
+            meta = self.metadata
+        else:
+            meta = {k: v for k, v in self.metadata.items() if k in include_meta}
+
+        # determine arrays to include
+        if include_analysis == "all":
+            analysis_names = list(self.analysis.keys())
+        else:
+            analysis_names = [k for k in self.analysis.keys() if k in include_analysis]
+
+        data = []
+        for k, v in self.analysis.items():
+            if k in analysis_names:
+                if len(v) == 1:
+                    for value in v:
+                        data.append({**meta, "metric": k, "obs": np.nan, "value": value})
+                else:
+                    obsn = 1
+                    for value in v:
+                        data.append({**meta, "metric": k, "obs": int(obsn), "value": value})
+                        obsn += 1
+
+        df = pd.DataFrame(data)
+
+        return df
+
+    def misc_dataframe(self, id: str) -> pd.DataFrame:
+        """Fetch misc data as a pandas dataframe.
+
+        Args:
+            id: name of the misc item to load
+
+        Returns:
+            DataFrame with data from this session
+        """
+        if isinstance(self.misc[id], np.ndarray):
+            df = pd.DataFrame(self.misc[id])
+        else:
+            raise TypeError("Only numpy arrays can be loaded as a pandas DataFrame.")
+
+        df[df.select_dtypes(include='object').columns] = df.select_dtypes(include='object').map(decode_byteseq)
+
+        return df
+
     def __eq__(self, value: object) -> bool:
         """Test this Session for equality to another Session.
@@ -299,6 +577,9 @@ def _estimate_memory_use_itemized(self) -> dict[str, int]:
             **{f"signal.{sig.name}": sig._estimate_memory_use() for sig in self.signals.values()},
             **{f"epocs.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.epocs.items()},
             **{f"scalars.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.scalars.items()},
+            **{f"dlc.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.dlc.items()},
+            **{f"analysis.{k}": sys.getsizeof(k) + v.nbytes for k, v in self.analysis.items()},
+            # misc values are not guaranteed to be numpy arrays, so fall back to sys.getsizeof
+            **{f"misc.{k}": sys.getsizeof(k) + (v.nbytes if isinstance(v, np.ndarray) else sys.getsizeof(v)) for k, v in self.misc.items()},
         }
 
     def _estimate_memory_use(self) -> int:
@@ -342,6 +624,54 @@ def save(self, path: str):
             for k, scalar in self.scalars.items():
                 h5.create_dataset(f"/scalars/{k}", data=scalar, compression="gzip")
 
+            # save dlc data
+            h5.create_group("/dlc")
+            for k, dlc in self.dlc.items():
+                if isinstance(dlc, pd.DataFrame):
+                    dlc = dlc.copy()  # avoid mutating the session's copy while flattening for HDF5
+                    dlc.columns = dlc.columns.to_flat_index()
+
+                    for col in dlc.select_dtypes(include='object').columns:
+                        dlc[col] = dlc[col].apply(lambda x: ','.join(map(str, x)) if isinstance(x, (list, tuple)) else x)
+
+                    for col in dlc.select_dtypes(include='object').columns:
+                        dlc[col] = dlc[col].astype(np.bytes_).astype('S50')
+
+                    nparray = dlc.to_records(index=False)
+                    h5.create_dataset(f"/dlc/{k}", data=nparray)
+                elif isinstance(dlc, np.ndarray):
+                    h5.create_dataset(f"/dlc/{k}", data=dlc)
+
+            # save analysis data
+            h5.create_group("/analysis")
+            for k, analysis in self.analysis.items():
+                h5.create_dataset(f"/analysis/{k}", data=np.atleast_1d(analysis).astype(float), compression="gzip")
+
+            # save misc data
+            h5.create_group("/misc")
+            for k, misc in self.misc.items():
+                if isinstance(misc, pd.DataFrame):
+                    misc = misc.copy()
+                    # flatten MultiIndex columns and strip characters that are awkward as HDF5 field names
+                    misc.columns = [remove_non_letters(str(col)) for col in misc.columns.to_flat_index()]
+
+                    for col in misc.select_dtypes(include='object').columns:
+                        misc[col] = misc[col].apply(lambda x: ','.join(map(str, x)) if isinstance(x, (list, tuple)) else x)
+
+                    for col in misc.select_dtypes(include='object').columns:
+                        misc[col] = misc[col].astype(np.bytes_).astype('S50')
+
+                    nparray = misc.to_records(index=False)
+                    h5.create_dataset(f"/misc/{k}", data=nparray)
+                elif isinstance(misc, np.ndarray) and misc.dtype.names is not None:
+                    # structured arrays can be written directly; the astype(float) fallback below would fail on them
+                    h5.create_dataset(f"/misc/{k}", data=misc)
+                else:
+                    h5.create_dataset(f"/misc/{k}", data=np.atleast_1d(misc).astype(float), compression="gzip")
+
             # save metadata
             meta_group = h5.create_group("/metadata")
             for k, v in self.metadata.items():
@@ -408,7 +738,11 @@
         for signame in h5["/signals"].keys():
             sig_group = h5[f"/signals/{signame}"]
             sig = Signal(
-                signame, sig_group["signal"][()], time=sig_group["time"][()], fs=sig_group.attrs["fs"], units=sig_group.attrs["units"]
+                signame,
+                sig_group["signal"][()],
+                time=sig_group["time"][()],
+                fs=sig_group.attrs["fs"],
+                units=sig_group.attrs["units"],
             )
             for mark_name in sig_group["marks"].keys():
                 sig.marks[mark_name] = sig_group[f"marks/{mark_name}"][()]
@@ -424,6 +758,21 @@
         for scalar_name in h5["/scalars"].keys():
             session.scalars[scalar_name] = h5[f"/scalars/{scalar_name}"][()]
 
+        # read dlc
+        if "/dlc" in h5:
+            for dlc_name in h5["/dlc"].keys():
+                session.dlc[dlc_name] = h5[f"/dlc/{dlc_name}"][()]
+
+        # read analysis
+        if "/analysis" in h5:
+            for analysis_name in h5["/analysis"].keys():
+                session.analysis[analysis_name] = h5[f"/analysis/{analysis_name}"][()]
+
+        # read misc
+        if "/misc" in h5:
+            for misc_name in h5["/misc"].keys():
+                session.misc[misc_name] = h5[f"/misc/{misc_name}"][()]
+
         # read metadata
         if "/metadata" in h5:
             for meta_name in h5[f"/metadata"].keys():
@@ -538,6 +887,26 @@ def rename_scalar(self, old_name: str, new_name: str) -> None:
         for item in self:
             item.rename_scalar(old_name, new_name)
 
+    def rename_analysis(self, old_name: str, new_name: str) -> None:
+        """Rename an analysis array on each session in this collection.
+
+        Args:
+            old_name: current name of the analysis array
+            new_name: the new name for the analysis array
+        """
+        for item in self:
+            item.rename_analysis(old_name, new_name)
+
+    def rename_misc(self, old_name: str, new_name: str) -> None:
+        """Rename a misc item on each session in this collection.
+
+        Args:
+            old_name: current name of the misc item
+            new_name: the new name for the misc item
+        """
+        for item in self:
+            item.rename_misc(old_name, new_name)
+
     def filter(self, predicate: Callable[[Session], bool]) -> "SessionCollection":
         """Filter the items in this collection, returning a new `SessionCollection` containing sessions which pass `predicate`.
@@ -647,6 +1016,16 @@ def get_signal(self, name: str) -> list[Signal]:
             List of Signals, each corresponding to a single session
         """
         return [item.signals[name] for item in self if name in item.signals]
+
+    def add_epoc(self, name: str, epoc_func: Callable[[Session], np.ndarray]) -> None:
+        """Apply an epoc function to each session in this collection, adding the results to each session's epocs attribute.
+
+        Args:
+            name: name of the new epoc
+            epoc_func: callable accepting a single session and returning a 1d numpy array of epoc timestamps
+        """
+        for session in self:
+            session.add_epoc(epoc_func(session), name)
 
     def epoc_dataframe(self, include_epocs: FieldList = "all", include_meta: FieldList = "all") -> pd.DataFrame:
         """Produce a dataframe with epoc data and metadata across all the sessions in this collection.
@@ -709,6 +1088,125 @@ def signal_dataframe(self, signal: str, include_meta: FieldList = "all") -> pd.D
             dfs.append(df)
 
         return pd.concat(dfs, ignore_index=True)
+
+    def analysis_dataframe(self, include_analysis: FieldList = "all", include_meta: FieldList = "all") -> pd.DataFrame:
+        """Produce a dataframe with analysis data and metadata across all sessions in this collection.
+
+        Args:
+            include_analysis: list of analysis names to include in the dataframe. Special str "all" is also accepted.
+            include_meta: list of metadata fields to include in the dataframe. Special str "all" is also accepted.
+
+        Returns:
+            DataFrame with analysis data from across this collection
+        """
+        dfs = [session.analysis_dataframe(include_analysis=include_analysis, include_meta=include_meta) for session in self]
+        return pd.concat(dfs).reset_index(drop=True)
+
+    def add_analysis(self, name: str, analysis: Callable[[Session], np.ndarray]) -> None:
+        """Apply an analysis function to each session in this collection, adding the results to each session's analysis attribute.
+
+        Args:
+            name: name of the new analysis data key
+            analysis: callable accepting a single session and returning a 1d numpy array
+        """
+        for session in self:
+            session.add_analysis(analysis(session), name)
+
+    def run_analysis(self, analysis: Callable[[Session], Session], max_workers: Optional[int] = None) -> "SessionCollection":
+        """Run an analysis function on each session in this collection and return a new SessionCollection.
+
+        Args:
+            analysis: callable accepting a single session and returning a new session
+            max_workers: number of workers in the process pool for running the analysis. If None, defaults to the number of CPUs on the machine.
+
+        Returns:
+            a new `SessionCollection` containing the results of `analysis`
+        """
+        futures: dict[Future[Session], str] = {}
+        context = multiprocessing.get_context("spawn")
+        max_tasks_per_child = 1
+
+        sc = SessionCollection()
+
+        with ProcessPoolExecutor(max_workers=max_workers, mp_context=context, max_tasks_per_child=max_tasks_per_child) as executor:
+            # submit one task to the pool per Session in the collection
+            for s in self:
+                f = executor.submit(analysis, s)
+                futures[f] = s.metadata["blockname"]
+
+            # compile the new SessionCollection
+            for f in tqdm(as_completed(futures), total=len(futures)):
+                try:
+                    sc.append(f.result())
+                except Exception:
+                    tqdm.write(
+                        f'Problem running analysis at "{futures[f]}":\n{traceback.format_exc()}\nThe session will be missing from the resultant SessionCollection!\n'
+                    )
+
+        return sc
+
+    def run_analysis_kw_test(self, analysis: Callable[[Session], Session], max_workers: Optional[int] = None, analysis_kwargs: Optional[dict] = None) -> "SessionCollection":
+        """Run an analysis function on each session in this collection and return a new SessionCollection.
+
+        Args:
+            analysis: callable accepting a single session, with optional additional keyword arguments, and returning a new session
+            max_workers: number of workers in the process pool for running the analysis. If None, defaults to the number of CPUs on the machine.
+            analysis_kwargs: kwargs to pass to the analysis callable
+
+        Returns:
+            a new `SessionCollection` containing the results of `analysis`
+        """
+        futures: dict[Future[Session], str] = {}
+        context = multiprocessing.get_context("spawn")
+        max_tasks_per_child = 1
+
+        _analysis_kwargs = {}
+        if analysis_kwargs is not None:
+            _analysis_kwargs.update(analysis_kwargs)
+
+        sc = SessionCollection()
+
+        with ProcessPoolExecutor(max_workers=max_workers, mp_context=context, max_tasks_per_child=max_tasks_per_child) as executor:
+            # submit one task to the pool per Session in the collection
+            for s in self:
+                f = executor.submit(analysis, s, **_analysis_kwargs)
+                futures[f] = s.metadata["blockname"]
+
+            # compile the new SessionCollection
+            for f in tqdm(as_completed(futures), total=len(futures)):
+                try:
+                    sc.append(f.result())
+                except Exception:
+                    tqdm.write(
+                        f'Problem running analysis at "{futures[f]}":\n{traceback.format_exc()}\nThe session will be missing from the resultant SessionCollection!\n'
+                    )
+
+        return sc
+
+    def map(self, action: Callable[[Session], Session]) -> "SessionCollection":
+        """Apply a function to each session in this collection, returning a new collection with the results.
+
+        Args:
+            action: callable accepting a single session and returning a new session
+
+        Returns:
+            a new `SessionCollection` containing the results of `action`
+        """
+        sc = type(self)(action(item) for item in self)
+        sc.__meta_meta.update(**copy.deepcopy(self.__meta_meta))
+        return sc
 
     def aggregate_signals(self, name: str, method: Union[None, str, np.ufunc, Callable[[np.ndarray], np.ndarray]] = "median") -> Signal:
         """Aggregate signals across sessions in this collection for the signal name `name`.
@@ -765,6 +1263,24 @@ def describe(self, as_str: bool = False) -> Union[str, None]:
             buffer += f'({v}) "{k}"\n'
         buffer += "\n"
 
+        dlcs = Counter([item for session in self for item in session.dlc.keys()])
+        buffer += "DLC Structured Numpy Arrays present in data:\n"
+        for k, v in dlcs.items():
+            buffer += f'({v}) "{k}"\n'
+        buffer += "\n"
+
+        analysis = Counter([item for session in self for item in session.analysis.keys()])
+        buffer += "Analysis arrays present in data:\n"
+        for k, v in analysis.items():
+            buffer += f'({v}) "{k}"\n'
+        buffer += "\n"
+
+        misc = Counter([item for session in self for item in session.misc.keys()])
+        buffer += "Misc items present in data:\n"
+        for k, v in misc.items():
+            buffer += f'({v}) "{k}"\n'
+        buffer += "\n"
+
         if as_str:
             return buffer
         else:
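[Note] Taken together, add_analysis and analysis_dataframe support a simple split-apply-combine workflow over a collection. A hedged sketch (the collection, the epoc name, and the metric name are hypothetical; it assumes each session carries a 1D "trial_start" epoc):

    import numpy as np

    def intertrial_intervals(session):
        # must return a 1d numpy array, per Session.add_analysis
        return np.diff(session.epocs["trial_start"])

    sessions.add_analysis("intertrial_interval", intertrial_intervals)
    df = sessions.analysis_dataframe(include_analysis=["intertrial_interval"])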
diff --git a/fptools/io/tdt_with_dlc.py b/fptools/io/tdt_with_dlc.py
new file mode 100644
index 0000000..c69d119
--- /dev/null
+++ b/fptools/io/tdt_with_dlc.py
@@ -0,0 +1,159 @@
+import glob
+import os
+from typing import Optional
+
+import pandas as pd
+
+from .common import DataTypeAdaptor
+from .session import Session
+from .tdt import TDT_EXCLUDE_STREAMS, TDTLoader
+
+
+def has_neighboring_dlc_h5(tbk: str) -> bool:
+    """Check whether a TBK file has a neighboring H5 file that looks like a DLC data output.
+
+    Args:
+        tbk: TBK file path to check
+
+    Returns:
+        True if the TBK file has DLC neighbors, otherwise False
+    """
+    dlc_id_substrs = ["DLC", "shuffle", "snapshot"]
+    neighbor_files = glob.glob(os.path.join(os.path.dirname(tbk), "*"))
+    neighbor_files = [file for file in neighbor_files if file.endswith(".h5")]
+
+    dlc_files = [file for file in neighbor_files if all(sub in file for sub in dlc_id_substrs)]
+
+    return len(dlc_files) > 0
+
+
+class FindTDTDLCBlocks:
+
+    def __init__(self, model_name: Optional[list[str]] = None, filtered_only: bool = True):
+        """Initialize this TDT-DLC Data Locator.
+
+        Args:
+            model_name: if provided, only look for DLC files with these model name(s); if None, load all files that look like DLC data
+            filtered_only: if True, only load filtered DLC data; otherwise, load any DLC data
+        """
+        self.model_name = model_name
+        self.filtered_only = filtered_only
+
+    def __call__(self, path: str) -> list[DataTypeAdaptor]:
+        """Data Locator for TDT blocks with DLC data.
+
+        Given a path to a directory, will search that path recursively for TDT blocks that contain DLC output files in .h5 format.
+
+        Args:
+            path: path to search for TDT blocks with DLC data
+
+        Returns:
+            list of DataTypeAdaptor, each adaptor corresponding to one session of data to be loaded
+        """
+        tbk_files = glob.glob(os.path.join(path, "**/*.[tT][bB][kK]"), recursive=True)
+        # tbk_files = [file for file in tbk_files if has_neighboring_dlc_h5(file)]
+
+        items_out = []
+        for tbk in tbk_files:
+            adapt = DataTypeAdaptor()
+            adapt.path = os.path.dirname(tbk)  # the directory for the block
+            adapt.name = os.path.basename(adapt.path)  # the name of the directory
+            adapt.loaders.append(TDTLoader(exclude_streams=TDT_EXCLUDE_STREAMS))
+            adapt.loaders.append(DLCLoader(model_name=self.model_name, filtered_only=self.filtered_only))
+            items_out.append(adapt)
+
+        return items_out
+
+
+class DLCLoader:
+    def __init__(self, model_name: Optional[list[str]] = None, filtered_only: bool = True) -> None:
+        """Initialize this DLCLoader."""
+        self.model_name = model_name
+        self.filtered_only = filtered_only
+
+    def __call__(self, session: Session, path: str) -> Session:
+        """Data Loader for DLC .h5 files in TDT blocks.
+
+        Args:
+            session: the session for data to be loaded into
+            path: path to a TDT block folder containing DLC data
+
+        Returns:
+            Session object with data added
+        """
+        if self.model_name is None:
+            # find any DLC .h5 files in the block
+            pattern = os.path.join(path, "*DLC*.h5")
+            files = glob.glob(pattern)
+
+            if len(files) <= 0:
+                raise FileNotFoundError(f"Could not find any DLC files in block {session.name}!")
+
+            for file in files:
+                df = pd.read_hdf(file)
+                key = f"{df.columns[0][0]}"  # the scorer level of the column MultiIndex
+                df.columns = df.columns.droplevel(level=0).to_flat_index()
+                nparray = df.to_records(index=False)
+
+                if "_filtered" in file:
+                    key += "_filtered"
+                    session.dlc[key] = nparray
+                elif not self.filtered_only:
+                    key += "_unfiltered"
+                    session.dlc[key] = nparray
+        else:
+            for mn in self.model_name:
+                if self.filtered_only:
+                    pattern = os.path.join(path, f"*{mn}*_filtered.h5")
+                else:
+                    pattern = os.path.join(path, f"*{mn}*.h5")
+                files = glob.glob(pattern)
+                if len(files) <= 0:
+                    raise FileNotFoundError(f'Could not find any DLC files for model "{mn}" in block {session.name}!')
+
+                for file in files:
+                    df = pd.read_hdf(file)
+                    df.columns = df.columns.droplevel(level=0).to_flat_index()
+                    nparray = df.to_records(index=False)
+                    key = f"{mn}"
+                    if "_filtered" in file:
+                        key += "_filtered"
+                    else:
+                        key += "_unfiltered"
+                    session.dlc[key] = nparray
+
+        return session
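[Note] Because FindTDTDLCBlocks is a callable object, a configured instance can be passed directly wherever a DataLocator is accepted, instead of the "tdt_dlc" shorthand. A sketch (the model name and search path are hypothetical):

    from fptools.io.data_loader import load_data
    from fptools.io.tdt_with_dlc import FindTDTDLCBlocks

    locator = FindTDTDLCBlocks(model_name=["resnet50_mouse"], filtered_only=False)
    # the locator alone just enumerates blocks as DataTypeAdaptors:
    adaptors = locator("/path/to/tdt_blocks")
    # or hand it to load_data to locate and load in one step:
    sessions = load_data("/path/to/tdt_blocks", locator=locator)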
diff --git a/fptools/preprocess/steps/dlc_interp.py b/fptools/preprocess/steps/dlc_interp.py
new file mode 100644
index 0000000..53e4544
--- /dev/null
+++ b/fptools/preprocess/steps/dlc_interp.py
@@ -0,0 +1,19 @@
+from matplotlib.axes import Axes
+
+from fptools.io.session import Session
+from fptools.preprocess.common import ProcessorThatPlots
+
+
+class DLCInterpolation(ProcessorThatPlots):
+    """A `Processor` that interpolates missing frames in dlc data."""
+
+    def __init__(self) -> None:
+        """Initialize this Processor."""
+        pass
+
+    def __call__(self, session: Session) -> Session:
+        # TODO: interpolation is not implemented yet; fail loudly rather than silently returning None
+        raise NotImplementedError
+
+    def plot(self, session: Session, ax: Axes):
+        raise NotImplementedError
diff --git a/fptools/viz/signal.py b/fptools/viz/signal.py
index b89ce83..3b5eb4d 100644
--- a/fptools/viz/signal.py
+++ b/fptools/viz/signal.py
@@ -11,7 +11,7 @@
 from typing import Any, Literal, Optional, Union
 
 from fptools.io import Signal, SessionCollection
-from .common import Palette, get_colormap
+from fptools.viz.common import Palette, get_colormap
 
 
 def plot_signal(
diff --git a/fptools/viz/video.py b/fptools/viz/video.py
new file mode 100644
index 0000000..d30c8ad
--- /dev/null
+++ b/fptools/viz/video.py
@@ -0,0 +1,30 @@
+from typing import Union
+
+import cv2
+import numpy as np
+
+
+def get_frame_image(video: str, frame_idx: int) -> Union[np.ndarray, None]:
+    """Get a frame from a video using openCV.
+
+    Args:
+        video: string path to the video file
+        frame_idx: integer index of the frame to grab
+
+    Returns:
+        numpy array of the image at `frame_idx`, or None if the frame could not be retrieved
+    """
+    cap = cv2.VideoCapture(video)
+    if not cap.isOpened():
+        raise IOError(f"Could not open video file: {video}")
+
+    cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
+    ret, frame = cap.read()
+    # release the capture before any return so the file handle is not leaked on a failed read
+    cap.release()
+
+    if not ret:
+        print(f"Error: Could not read frame {frame_idx}.")
+        return None
+
+    return frame
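[Note] OpenCV returns frames in BGR channel order, so convert before displaying with matplotlib. A usage sketch (the video path and frame index are hypothetical):

    import cv2
    import matplotlib.pyplot as plt
    from fptools.viz.video import get_frame_image

    frame = get_frame_image("/path/to/session_video.avi", 100)
    if frame is not None:
        plt.imshow(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        plt.axis("off")
        plt.show()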
diff --git a/pyproject.toml b/pyproject.toml
index fd388a1..18b902a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ authors = [
     {name = "Joshua K. Thackray", email = "thackray@rutgers.edu"},
 ]
 description="Collection of tools for working with fiber photometry data."
-requires-python = ">=3.12,<=3.13"
+requires-python = ">=3.12,<=3.13.2"
 keywords = ["fiber photometry", "behavior"]
 license = {text = "BSD-3-Clause"}
 classifiers = [