Conversation

@schewskone
Collaborator

Implemented support for compressed screen formats

  • Added EncodedImageTrial and EncodedVideoTrial which feature get_data_() methods that support encoded formats.
  • Extended screen tests and screen data generation to feature encoded data in mp4 and jpeg format.
  • Thinned requirements.txt and added ffmpeg installation to CI workflow.

Right now the VideoDecoder decodes the entire video and returns it to the interpolation function. This could be optimized with sliced decoding, which would require changes to the interpolation method. @pollytur and I decided to discuss this first and implement it later if deemed necessary.
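
To make the trade-off concrete, here is a minimal sketch of the index arithmetic that sliced decoding would need: mapping the requested timestamps to the smallest frame range that covers them, so that only that range has to be decoded. The function name `frame_slice_for_times` and its parameters are hypothetical and not part of this PR; the sketch assumes a constant frame rate and that each timestamp maps to the most recent frame at or before it.

```python
import math
from typing import Sequence, Tuple


def frame_slice_for_times(
    times: Sequence[float],
    start_time: float,
    fps: float,
    num_frames: int,
) -> Tuple[int, int]:
    """Return a half-open frame range [first, last) covering all `times`.

    Assumes a constant frame rate (hypothetical simplification).
    """
    if len(times) == 0:
        return (0, 0)

    # Map the earliest and latest timestamps to frame indices.
    first = math.floor((min(times) - start_time) * fps)
    last = math.floor((max(times) - start_time) * fps) + 1

    # Clamp to the frames that actually exist in the video.
    first = max(0, min(first, num_frames))
    last = max(first, min(last, num_frames))
    return (first, last)


# Two timestamps at 30 fps need only frames 30..45 instead of the full video.
needed = frame_slice_for_times([1.0, 1.5], start_time=0.0, fps=30.0, num_frames=100)
```

The resulting range could be handed to whichever decoder is in use; how the interpolation method would consume a partial video is the open design question raised in this description.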

@pollytur pollytur requested a review from Copilot July 3, 2025 10:07
Contributor

Copilot AI left a comment

Pull Request Overview

This PR adds support for compressed screen data by extending trials to handle encoded images and videos, updating tests and data generation to work with JPEG/MP4, and adjusting interpolation interfaces and CI.

  • Introduce EncodedImageTrial and EncodedVideoTrial with compressed data loaders
  • Enhance test utilities and screen interpolation tests for both encoded and raw formats
  • Update interpolation API to return only data, adjust downstream consumers, and add FFmpeg to CI

Reviewed Changes

Copilot reviewed 8 out of 9 changed files in this pull request and generated 2 comments.

| File | Description |
| --- | --- |
| tests/test_sequence_interpolator.py | Update tests for new single-array return value of interpolate |
| tests/test_screen_interpolator.py | Add parameterized tests for encoded vs. raw screen data |
| tests/create_screen_data.py | Extend data generation to save JPEG/MP4 and emit metadata |
| experanto/interpolators.py | Remove valid returns, introduce image_names flag, add encoded trials |
| experanto/experiment.py | Adjust interpolate to drop valid output |
| experanto/datasets.py | Update dataset pipelines to match new return signature |
| configs/default.yaml | Add image_names configuration |
| .github/workflows/test.yml | Install FFmpeg for encoding dependencies |

Comments suppressed due to low confidence (4)

experanto/interpolators.py:152

  • The return type annotation still indicates a tuple, but the method now returns only a single array. Update the type hint and docstring to reflect the new signature.
    def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]:

experanto/interpolators.py:425

  • [nitpick] Using format shadows the built-in Python function. Consider renaming this variable to file_format to avoid confusion.
            format = metadata.get("file_format")

tests/create_screen_data.py:14

  • The new encoded parameter isn't documented. Please update the function docstring to explain what encoded does and what formats it controls.
def create_screen_data(

tests/create_screen_data.py:118

  • The code calls shutil.rmtree but shutil is not imported. Add import shutil at the top of the module.
        shutil.rmtree(SCREEN_ROOT)
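
The naming nitpick on experanto/interpolators.py:425 can be illustrated with a small sketch. The `metadata` dictionary and both function names below are hypothetical; the point is only that a local variable called `format` hides the built-in `format()` for the rest of that scope, while a name such as `file_format` leaves the built-in usable.

```python
metadata = {"file_format": "jpeg"}  # hypothetical metadata, for illustration only


def describe_shadowed(meta: dict) -> str:
    format = meta.get("file_format")  # local name now hides the built-in
    try:
        # This tries to call the string "jpeg", not the built-in format().
        return format(0.5, ".0%")
    except TypeError as exc:
        return f"TypeError: {exc}"


def describe_renamed(meta: dict) -> str:
    file_format = meta.get("file_format")  # built-in format() stays available
    return f"{file_format} at {format(0.5, '.0%')}"


shadowed = describe_shadowed(metadata)
renamed = describe_renamed(metadata)
```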

Contributor

@pollytur pollytur Jul 3, 2025

I guess a higher-level question here is whether we want to simulate any codecs other than jpeg and mp4.

Was there anything else in Allen data?

Contributor

@pollytur pollytur Jul 3, 2025

for me to check

  • burgeon data -> spy only
  • IBL data -> .parquet, as it's time-series data only (link)
  • ask Caio -> no

@pollytur
Contributor

@schewskone needs to check which codecs are supported and whether num_workers >= 1 works

…l video decoding, VideoDecoder objects are now shared across trials
Copilot AI review requested due to automatic review settings January 27, 2026 19:58
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 10 changed files in this pull request and generated 21 comments.

Comments suppressed due to low confidence (3)

experanto/interpolators.py:824

  • The type hint for data_file_name parameter is str, but this is being passed a Path object from the _parse_trials method (line 446). OpenCV's VideoCapture can handle both str and Path, but for consistency and correctness, either the type hint should be updated to accept Union[str, Path], or the Path should be converted to str at the call site.
        self.first_frame_idx = first_frame_idx
        self.num_frames = num_frames
        self._cached_data = None
        self._cache_data = cache_data
        if self._cache_data:
            self._cached_data = self.get_data_()

experanto/interpolators.py:793

  • The TimeIntervalInterpolator class is defined twice in this file (lines 614-703 and lines 705-793). This is a critical bug that will cause the second definition to overwrite the first. One of these duplicate definitions should be removed.
                        interpolation=cv2.INTER_AREA,
                    )
                )
                for frame in data
            ],
            axis=0,
        )


class TimeIntervalInterpolator(Interpolator):
    def __init__(self, root_folder: str, cache_data: bool = False, **kwargs):
        super().__init__(root_folder)
        self.cache_data = cache_data

        meta = self.load_meta()
        self.meta_labels = meta["labels"]
        self.start_time = meta["start_time"]
        self.end_time = meta["end_time"]
        self.valid_interval = TimeInterval(self.start_time, self.end_time)

        if self.cache_data:
            self.labeled_intervals = {
                label: np.load(self.root_folder / filename)
                for label, filename in self.meta_labels.items()
            }

    def interpolate(self, times: np.ndarray) -> np.ndarray:
        """
        Interpolate time intervals for labeled events.

        Given a set of time points and a set of labeled intervals (defined in the
        `meta.yml` file), this method returns a boolean array indicating, for each
        time point, whether it falls within any interval for each label.

        The method uses half-open intervals [start, end), where a timestamp t is
        considered to fall within an interval if start <= t < end. This means the
        start time is inclusive and the end time is exclusive.

        Parameters
        ----------
        times : np.ndarray
            Array of time points to be checked against the labeled intervals.

        Returns
        -------
        out : np.ndarray of bool, shape (len(valid_times), n_labels)
            Boolean array where each row corresponds to a valid time point and each
            column corresponds to a label. `out[i, j]` is True if the i-th valid
            time falls within any interval for the j-th label, and False otherwise.

        Notes
        -----
        - The labels and their corresponding intervals are defined in the `meta.yml`
          file under the `labels` key. Each label points to a `.npy` file containing
          an array of shape (n, 2), where each row is a [start, end) time interval.
        - Typical labels might include 'train', 'validation', 'test', 'saccade',
          'gaze', or 'target'.
        - Only time points within the valid interval (as defined by start_time and
          end_time in meta.yml) are considered; others are filtered out.
        - Intervals where start > end are considered invalid and will trigger a
          warning.
        """
        valid = self.valid_times(times)
        valid_times = times[valid]

        n_labels = len(self.meta_labels)
        n_times = len(valid_times)

        if n_times == 0:
            warnings.warn(
                "TimeIntervalInterpolator returns an empty array, no valid times queried."
            )
            return np.empty((0, n_labels), dtype=bool)

        out = np.zeros((n_times, n_labels), dtype=bool)
        for i, (label, filename) in enumerate(self.meta_labels.items()):
            if self.cache_data:
                intervals = self.labeled_intervals[label]
            else:
                intervals = np.load(self.root_folder / filename, allow_pickle=True)

            if len(intervals) == 0:
                warnings.warn(
                    f"TimeIntervalInterpolator found no intervals for label: {label}"
                )
                continue

            for start, end in intervals:
                if start > end:
                    warnings.warn(
                        f"Invalid interval found for label: {label}, interval: ({start}, {end})"
                    )
                    continue
                # Half-open interval [start, end): inclusive start, exclusive end
                mask = (valid_times >= start) & (valid_times < end)
                out[mask, i] = True

        return out


class TimeIntervalInterpolator(Interpolator):
    def __init__(self, root_folder: str, cache_data: bool = False, **kwargs):
        super().__init__(root_folder)
        self.cache_data = cache_data

        meta = self.load_meta()
        self.meta_labels = meta["labels"]
        self.start_time = meta["start_time"]
        self.end_time = meta["end_time"]
        self.valid_interval = TimeInterval(self.start_time, self.end_time)

        if self.cache_data:
            self.labeled_intervals = {
                label: np.load(self.root_folder / filename)
                for label, filename in self.meta_labels.items()
            }

    def interpolate(self, times: np.ndarray) -> np.ndarray:
        """
        Interpolate time intervals for labeled events.

        Given a set of time points and a set of labeled intervals (defined in the
        `meta.yml` file), this method returns a boolean array indicating, for each
        time point, whether it falls within any interval for each label.

        The method uses half-open intervals [start, end), where a timestamp t is
        considered to fall within an interval if start <= t < end. This means the
        start time is inclusive and the end time is exclusive.

        Parameters
        ----------
        times : np.ndarray
            Array of time points to be checked against the labeled intervals.

        Returns
        -------
        out : np.ndarray of bool, shape (len(valid_times), n_labels)
            Boolean array where each row corresponds to a valid time point and each
            column corresponds to a label. `out[i, j]` is True if the i-th valid
            time falls within any interval for the j-th label, and False otherwise.

        Notes
        -----
        - The labels and their corresponding intervals are defined in the `meta.yml`
          file under the `labels` key. Each label points to a `.npy` file containing
          an array of shape (n, 2), where each row is a [start, end) time interval.
        - Typical labels might include 'train', 'validation', 'test', 'saccade',
          'gaze', or 'target'.
        - Only time points within the valid interval (as defined by start_time and
          end_time in meta.yml) are considered; others are filtered out.
        - Intervals where start > end are considered invalid and will trigger a
          warning.
        """
        valid = self.valid_times(times)
        valid_times = times[valid]

        n_labels = len(self.meta_labels)
        n_times = len(valid_times)

        if n_times == 0:
            warnings.warn(
                "TimeIntervalInterpolator returns an empty array, no valid times queried."
            )
            return np.empty((0, n_labels), dtype=bool)

        out = np.zeros((n_times, n_labels), dtype=bool)
        for i, (label, filename) in enumerate(self.meta_labels.items()):
            if self.cache_data:
                intervals = self.labeled_intervals[label]
            else:
                intervals = np.load(self.root_folder / filename, allow_pickle=True)

            if len(intervals) == 0:
                warnings.warn(
                    f"TimeIntervalInterpolator found no intervals for label: {label}"
                )
                continue

            for start, end in intervals:
                if start > end:
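
The effect of the duplicated definition can be shown with a small, self-contained sketch. The class `Demo` below is a stand-in with a made-up method, not the class from experanto/interpolators.py; it only illustrates that a second `class` statement with the same name rebinds that name, so the first definition is no longer reachable through it.

```python
class Demo:
    def which(self) -> str:
        return "first"


first_cls = Demo  # keep a reference to the first class object


class Demo:  # same name: this statement rebinds `Demo` to a new class
    def which(self) -> str:
        return "second"


current = Demo().which()         # resolves to the second definition
same_object = first_cls is Demo  # the two class objects are distinct
```

Because the two copies in the PR are identical, behavior does not change today, but any later edit applied to only the first copy would silently have no effect.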

experanto/interpolators.py:815

        self.data_file_name = data_file_name

Copilot AI review requested due to automatic review settings January 27, 2026 20:23
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 10 changed files in this pull request and generated 22 comments.

Comments suppressed due to low confidence (2)

experanto/interpolators.py:802

  • Duplicate class definition detected. The class TimeIntervalInterpolator is defined twice in this file (once at line 623 and again at line 714). This will cause the second definition to overwrite the first one, which is likely unintended. The duplicate should be removed.

experanto/interpolators.py:824

            self._cached_data = self.get_data_()

Copilot AI review requested due to automatic review settings January 27, 2026 20:41
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 10 changed files in this pull request and generated 7 comments.

Comments suppressed due to low confidence (2)

experanto/interpolators.py:802

  • The TimeIntervalInterpolator class is defined twice in this file. The second definition (lines 714-802) is a complete duplicate of the first one (lines 623-711). This will cause the second definition to override the first, but having duplicate code is confusing and should be removed. Only one definition should remain.

experanto/interpolators.py:824

            self._cached_data = self.get_data_()

Copilot AI review requested due to automatic review settings January 27, 2026 21:33
Contributor

Copilot AI left a comment

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

Copilot AI review requested due to automatic review settings January 27, 2026 21:44
Contributor

Copilot AI left a comment

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

Copilot AI review requested due to automatic review settings January 27, 2026 21:50
Contributor

Copilot AI left a comment

Pull request overview

Copilot reviewed 8 out of 10 changed files in this pull request and generated 16 comments.

Comments suppressed due to low confidence (1)

experanto/interpolators.py:826

            self._cached_data = self.get_data_()


def get_data_(self) -> np.array:
    """Override base implementation to load compressed images"""
    img = cv2.imread(self.data_file_name)

Copilot AI Jan 27, 2026

cv2.imread requires a string path, but data_file_name might be a Path object (as shown in line 448 where Path objects are created). This could cause cv2.imread to fail. Convert to string: img = cv2.imread(str(self.data_file_name)).

Suggested change
- img = cv2.imread(self.data_file_name)
+ img = cv2.imread(str(self.data_file_name))
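
A minimal sketch of the same idea as a reusable helper, assuming the goal is to accept both `str` and `pathlib.Path` at the call site. The helper name `to_str_path` and the file name are hypothetical; `os.fspath` is the standard-library call that converts a path-like object to its string form, and the result can be passed to any API that expects a plain string path.

```python
import os
from pathlib import Path
from typing import Union


def to_str_path(path: Union[str, "os.PathLike[str]"]) -> str:
    """Normalize a str or path-like object to a plain string path."""
    return os.fspath(path)


as_str = to_str_path(Path("frame_0001.jpg"))  # Path in, str out
unchanged = to_str_path("frame_0001.jpg")     # str passes through as-is
```

Doing the conversion in one place would also let the `data_file_name` type hint be widened to cover both types without touching every call site.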

@schewskone schewskone requested a review from pollytur January 29, 2026 15:58