-
Notifications
You must be signed in to change notification settings - Fork 19
Data compression for screen data #81
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Data compression for screen data #81
Conversation
…ges and mp4 videos
… for compatibility with allen-exporter and future datasets to avoid duplicate files
Adding small fix for ToTensor issue to decoder PR because seperate PR is unnessesary
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR adds support for compressed screen data by extending trials to handle encoded images and videos, updating tests and data generation to work with JPEG/MP4, and adjusting interpolation interfaces and CI.
- Introduce EncodedImageTrial and EncodedVideoTrial with compressed data loaders
- Enhance test utilities and screen interpolation tests for both encoded and raw formats
- Update interpolation API to return only data, adjust downstream consumers, and add FFmpeg to CI
Reviewed Changes
Copilot reviewed 8 out of 9 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_sequence_interpolator.py | Update tests for new single-array return value of interpolate |
| tests/test_screen_interpolator.py | Add parameterized tests for encoded vs. raw screen data |
| tests/create_screen_data.py | Extend data generation to save JPEG/MP4 and emit metadata |
| experanto/interpolators.py | Remove valid returns, introduce image_names flag, add encoded trials |
| experanto/experiment.py | Adjust interpolate to drop valid output |
| experanto/datasets.py | Update dataset pipelines to match new return signature |
| configs/default.yaml | Add image_names configuration |
| .github/workflows/test.yml | Install FFmpeg for encoding dependencies |
Comments suppressed due to low confidence (4)
experanto/interpolators.py:152
- The return type annotation still indicates a tuple, but the method now returns only a single array. Update the type hint and docstring to reflect the new signature.
def interpolate(self, times: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
experanto/interpolators.py:425
- [nitpick] Using
formatshadows the built-in Python function. Consider renaming this variable tofile_formatto avoid confusion.
format = metadata.get("file_format")
tests/create_screen_data.py:14
- The new
encodedparameter isn't documented. Please update the function docstring to explain whatencodeddoes and what formats it controls.
def create_screen_data(
tests/create_screen_data.py:118
- The code calls
shutil.rmtreebutshutilis not imported. Addimport shutilat the top of the module.
shutil.rmtree(SCREEN_ROOT)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess a higher level comment here is if we want to simulate any other codecs than jpeg and mp4
Was there anything else in Allen data?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for me to check
- burgeon data -> spy only
- IBL data -> .parquet as its timeseries data only (link)
- ask Caio -> no
|
@schewskone needs to check which codes are supported and if |
…l video decoding, VideoDecoder objects are now shared across trials
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 8 out of 10 changed files in this pull request and generated 21 comments.
Comments suppressed due to low confidence (3)
experanto/interpolators.py:824
- The type hint for data_file_name parameter is
str, but this is being passed aPathobject from the _parse_trials method (line 446). OpenCV's VideoCapture can handle both str and Path, but for consistency and correctness, either the type hint should be updated to accept Union[str, Path], or the Path should be converted to str at the call site.
self.first_frame_idx = first_frame_idx
self.num_frames = num_frames
self._cached_data = None
self._cache_data = cache_data
if self._cache_data:
self._cached_data = self.get_data_()
experanto/interpolators.py:793
- The TimeIntervalInterpolator class is defined twice in this file (lines 614-703 and lines 705-793). This is a critical bug that will cause the second definition to overwrite the first. One of these duplicate definitions should be removed.
interpolation=cv2.INTER_AREA,
)
)
for frame in data
],
axis=0,
)
class TimeIntervalInterpolator(Interpolator):
def __init__(self, root_folder: str, cache_data: bool = False, **kwargs):
super().__init__(root_folder)
self.cache_data = cache_data
meta = self.load_meta()
self.meta_labels = meta["labels"]
self.start_time = meta["start_time"]
self.end_time = meta["end_time"]
self.valid_interval = TimeInterval(self.start_time, self.end_time)
if self.cache_data:
self.labeled_intervals = {
label: np.load(self.root_folder / filename)
for label, filename in self.meta_labels.items()
}
def interpolate(self, times: np.ndarray) -> np.ndarray:
"""
Interpolate time intervals for labeled events.
Given a set of time points and a set of labeled intervals (defined in the
`meta.yml` file), this method returns a boolean array indicating, for each
time point, whether it falls within any interval for each label.
The method uses half-open intervals [start, end), where a timestamp t is
considered to fall within an interval if start <= t < end. This means the
start time is inclusive and the end time is exclusive.
Parameters
----------
times : np.ndarray
Array of time points to be checked against the labeled intervals.
Returns
-------
out : np.ndarray of bool, shape (len(valid_times), n_labels)
Boolean array where each row corresponds to a valid time point and each
column corresponds to a label. `out[i, j]` is True if the i-th valid
time falls within any interval for the j-th label, and False otherwise.
Notes
-----
- The labels and their corresponding intervals are defined in the `meta.yml`
file under the `labels` key. Each label points to a `.npy` file containing
an array of shape (n, 2), where each row is a [start, end) time interval.
- Typical labels might include 'train', 'validation', 'test', 'saccade',
'gaze', or 'target'.
- Only time points within the valid interval (as defined by start_time and
end_time in meta.yml) are considered; others are filtered out.
- Intervals where start > end are considered invalid and will trigger a
warning.
"""
valid = self.valid_times(times)
valid_times = times[valid]
n_labels = len(self.meta_labels)
n_times = len(valid_times)
if n_times == 0:
warnings.warn(
"TimeIntervalInterpolator returns an empty array, no valid times queried."
)
return np.empty((0, n_labels), dtype=bool)
out = np.zeros((n_times, n_labels), dtype=bool)
for i, (label, filename) in enumerate(self.meta_labels.items()):
if self.cache_data:
intervals = self.labeled_intervals[label]
else:
intervals = np.load(self.root_folder / filename, allow_pickle=True)
if len(intervals) == 0:
warnings.warn(
f"TimeIntervalInterpolator found no intervals for label: {label}"
)
continue
for start, end in intervals:
if start > end:
warnings.warn(
f"Invalid interval found for label: {label}, interval: ({start}, {end})"
)
continue
# Half-open interval [start, end): inclusive start, exclusive end
mask = (valid_times >= start) & (valid_times < end)
out[mask, i] = True
return out
class TimeIntervalInterpolator(Interpolator):
def __init__(self, root_folder: str, cache_data: bool = False, **kwargs):
super().__init__(root_folder)
self.cache_data = cache_data
meta = self.load_meta()
self.meta_labels = meta["labels"]
self.start_time = meta["start_time"]
self.end_time = meta["end_time"]
self.valid_interval = TimeInterval(self.start_time, self.end_time)
if self.cache_data:
self.labeled_intervals = {
label: np.load(self.root_folder / filename)
for label, filename in self.meta_labels.items()
}
def interpolate(self, times: np.ndarray) -> np.ndarray:
"""
Interpolate time intervals for labeled events.
Given a set of time points and a set of labeled intervals (defined in the
`meta.yml` file), this method returns a boolean array indicating, for each
time point, whether it falls within any interval for each label.
The method uses half-open intervals [start, end), where a timestamp t is
considered to fall within an interval if start <= t < end. This means the
start time is inclusive and the end time is exclusive.
Parameters
----------
times : np.ndarray
Array of time points to be checked against the labeled intervals.
Returns
-------
out : np.ndarray of bool, shape (len(valid_times), n_labels)
Boolean array where each row corresponds to a valid time point and each
column corresponds to a label. `out[i, j]` is True if the i-th valid
time falls within any interval for the j-th label, and False otherwise.
Notes
-----
- The labels and their corresponding intervals are defined in the `meta.yml`
file under the `labels` key. Each label points to a `.npy` file containing
an array of shape (n, 2), where each row is a [start, end) time interval.
- Typical labels might include 'train', 'validation', 'test', 'saccade',
'gaze', or 'target'.
- Only time points within the valid interval (as defined by start_time and
end_time in meta.yml) are considered; others are filtered out.
- Intervals where start > end are considered invalid and will trigger a
warning.
"""
valid = self.valid_times(times)
valid_times = times[valid]
n_labels = len(self.meta_labels)
n_times = len(valid_times)
if n_times == 0:
warnings.warn(
"TimeIntervalInterpolator returns an empty array, no valid times queried."
)
return np.empty((0, n_labels), dtype=bool)
out = np.zeros((n_times, n_labels), dtype=bool)
for i, (label, filename) in enumerate(self.meta_labels.items()):
if self.cache_data:
intervals = self.labeled_intervals[label]
else:
intervals = np.load(self.root_folder / filename, allow_pickle=True)
if len(intervals) == 0:
warnings.warn(
f"TimeIntervalInterpolator found no intervals for label: {label}"
)
continue
for start, end in intervals:
if start > end:
experanto/interpolators.py:815
- This call to ScreenTrial.get_data_ in an initialization method is overridden by EncodedImageTrial.get_data_.
This call to ScreenTrial.get_data_ in an initialization method is overridden by EncodedVideoTrial.get_data_.
This call to ScreenTrial.get_data_ in an initialization method is overridden by BlankTrial.get_data_.
This call to ScreenTrial.get_data_ in an initialization method is overridden by InvalidTrial.get_data_.
self.data_file_name = data_file_name
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
9f12d84 to
16a2a14
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 8 out of 10 changed files in this pull request and generated 22 comments.
Comments suppressed due to low confidence (2)
experanto/interpolators.py:802
- Duplicate class definition detected. The class
TimeIntervalInterpolatoris defined twice in this file (once at line 623 and again at line 714). This will cause the second definition to overwrite the first one, which is likely unintended. The duplicate should be removed.
class TimeIntervalInterpolator(Interpolator):
def __init__(self, root_folder: str, cache_data: bool = False, **kwargs):
super().__init__(root_folder)
self.cache_data = cache_data
meta = self.load_meta()
self.meta_labels = meta["labels"]
self.start_time = meta["start_time"]
self.end_time = meta["end_time"]
self.valid_interval = TimeInterval(self.start_time, self.end_time)
if self.cache_data:
self.labeled_intervals = {
label: np.load(self.root_folder / filename)
for label, filename in self.meta_labels.items()
}
def interpolate(self, times: np.ndarray) -> np.ndarray:
"""
Interpolate time intervals for labeled events.
Given a set of time points and a set of labeled intervals (defined in the
`meta.yml` file), this method returns a boolean array indicating, for each
time point, whether it falls within any interval for each label.
The method uses half-open intervals [start, end), where a timestamp t is
considered to fall within an interval if start <= t < end. This means the
start time is inclusive and the end time is exclusive.
Parameters
----------
times : np.ndarray
Array of time points to be checked against the labeled intervals.
Returns
-------
out : np.ndarray of bool, shape (len(valid_times), n_labels)
Boolean array where each row corresponds to a valid time point and each
column corresponds to a label. `out[i, j]` is True if the i-th valid
time falls within any interval for the j-th label, and False otherwise.
Notes
-----
- The labels and their corresponding intervals are defined in the `meta.yml`
file under the `labels` key. Each label points to a `.npy` file containing
an array of shape (n, 2), where each row is a [start, end) time interval.
- Typical labels might include 'train', 'validation', 'test', 'saccade',
'gaze', or 'target'.
- Only time points within the valid interval (as defined by start_time and
end_time in meta.yml) are considered; others are filtered out.
- Intervals where start > end are considered invalid and will trigger a
warning.
"""
valid = self.valid_times(times)
valid_times = times[valid]
n_labels = len(self.meta_labels)
n_times = len(valid_times)
if n_times == 0:
warnings.warn(
"TimeIntervalInterpolator returns an empty array, no valid times queried."
)
return np.empty((0, n_labels), dtype=bool)
out = np.zeros((n_times, n_labels), dtype=bool)
for i, (label, filename) in enumerate(self.meta_labels.items()):
if self.cache_data:
intervals = self.labeled_intervals[label]
else:
intervals = np.load(self.root_folder / filename, allow_pickle=True)
if len(intervals) == 0:
warnings.warn(
f"TimeIntervalInterpolator found no intervals for label: {label}"
)
continue
for start, end in intervals:
if start > end:
warnings.warn(
f"Invalid interval found for label: {label}, interval: ({start}, {end})"
)
continue
# Half-open interval [start, end): inclusive start, exclusive end
mask = (valid_times >= start) & (valid_times < end)
out[mask, i] = True
return out
experanto/interpolators.py:824
- This call to ScreenTrial.get_data_ in an initialization method is overridden by EncodedImageTrial.get_data_.
This call to ScreenTrial.get_data_ in an initialization method is overridden by EncodedVideoTrial.get_data_.
This call to ScreenTrial.get_data_ in an initialization method is overridden by BlankTrial.get_data_.
This call to ScreenTrial.get_data_ in an initialization method is overridden by InvalidTrial.get_data_.
self._cached_data = self.get_data_()
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
cebd565 to
16a2a14
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 8 out of 10 changed files in this pull request and generated 7 comments.
Comments suppressed due to low confidence (2)
experanto/interpolators.py:802
- The TimeIntervalInterpolator class is defined twice in this file. The second definition (lines 714-802) is a complete duplicate of the first one (lines 623-711). This will cause the second definition to override the first, but having duplicate code is confusing and should be removed. Only one definition should remain.
class TimeIntervalInterpolator(Interpolator):
def __init__(self, root_folder: str, cache_data: bool = False, **kwargs):
super().__init__(root_folder)
self.cache_data = cache_data
meta = self.load_meta()
self.meta_labels = meta["labels"]
self.start_time = meta["start_time"]
self.end_time = meta["end_time"]
self.valid_interval = TimeInterval(self.start_time, self.end_time)
if self.cache_data:
self.labeled_intervals = {
label: np.load(self.root_folder / filename)
for label, filename in self.meta_labels.items()
}
def interpolate(self, times: np.ndarray) -> np.ndarray:
"""
Interpolate time intervals for labeled events.
Given a set of time points and a set of labeled intervals (defined in the
`meta.yml` file), this method returns a boolean array indicating, for each
time point, whether it falls within any interval for each label.
The method uses half-open intervals [start, end), where a timestamp t is
considered to fall within an interval if start <= t < end. This means the
start time is inclusive and the end time is exclusive.
Parameters
----------
times : np.ndarray
Array of time points to be checked against the labeled intervals.
Returns
-------
out : np.ndarray of bool, shape (len(valid_times), n_labels)
Boolean array where each row corresponds to a valid time point and each
column corresponds to a label. `out[i, j]` is True if the i-th valid
time falls within any interval for the j-th label, and False otherwise.
Notes
-----
- The labels and their corresponding intervals are defined in the `meta.yml`
file under the `labels` key. Each label points to a `.npy` file containing
an array of shape (n, 2), where each row is a [start, end) time interval.
- Typical labels might include 'train', 'validation', 'test', 'saccade',
'gaze', or 'target'.
- Only time points within the valid interval (as defined by start_time and
end_time in meta.yml) are considered; others are filtered out.
- Intervals where start > end are considered invalid and will trigger a
warning.
"""
valid = self.valid_times(times)
valid_times = times[valid]
n_labels = len(self.meta_labels)
n_times = len(valid_times)
if n_times == 0:
warnings.warn(
"TimeIntervalInterpolator returns an empty array, no valid times queried."
)
return np.empty((0, n_labels), dtype=bool)
out = np.zeros((n_times, n_labels), dtype=bool)
for i, (label, filename) in enumerate(self.meta_labels.items()):
if self.cache_data:
intervals = self.labeled_intervals[label]
else:
intervals = np.load(self.root_folder / filename, allow_pickle=True)
if len(intervals) == 0:
warnings.warn(
f"TimeIntervalInterpolator found no intervals for label: {label}"
)
continue
for start, end in intervals:
if start > end:
warnings.warn(
f"Invalid interval found for label: {label}, interval: ({start}, {end})"
)
continue
# Half-open interval [start, end): inclusive start, exclusive end
mask = (valid_times >= start) & (valid_times < end)
out[mask, i] = True
return out
experanto/interpolators.py:824
- This call to ScreenTrial.get_data_ in an initialization method is overridden by EncodedImageTrial.get_data_.
This call to ScreenTrial.get_data_ in an initialization method is overridden by EncodedVideoTrial.get_data_.
This call to ScreenTrial.get_data_ in an initialization method is overridden by BlankTrial.get_data_.
This call to ScreenTrial.get_data_ in an initialization method is overridden by InvalidTrial.get_data_.
self._cached_data = self.get_data_()
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
…before formatting of shapes
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.
3c5bb5a to
b3b0079
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull request overview
Copilot reviewed 8 out of 10 changed files in this pull request and generated 16 comments.
Comments suppressed due to low confidence (1)
experanto/interpolators.py:826
- This call to ScreenTrial.get_data_ in an initialization method is overridden by EncodedImageTrial.get_data_.
This call to ScreenTrial.get_data_ in an initialization method is overridden by EncodedVideoTrial.get_data_.
This call to ScreenTrial.get_data_ in an initialization method is overridden by BlankTrial.get_data_.
This call to ScreenTrial.get_data_ in an initialization method is overridden by InvalidTrial.get_data_.
self._cached_data = self.get_data_()
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
experanto/interpolators.py
Outdated
|
|
||
| def get_data_(self) -> np.array: | ||
| """Override base implementation to load compressed images""" | ||
| img = cv2.imread(self.data_file_name) |
Copilot
AI
Jan 27, 2026
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cv2.imread requires a string path, but data_file_name might be a Path object (as shown in line 448 where Path objects are created). This could cause cv2.imread to fail. Convert to string: img = cv2.imread(str(self.data_file_name)).
| img = cv2.imread(self.data_file_name) | |
| img = cv2.imread(str(self.data_file_name)) |
Implemented support for compressed screen formats
get_data_()methods that support encoded formats.Right now the VideoDecoder decodes the entire video and returns that to the interpolation function. This can be optimized with the help of sliced decoding which would require changes to the interpolation method. @pollytur and I decided we should discuss first and implement it later on if deemed necessary.