diff --git a/.gitignore b/.gitignore
index 9c4dda937c..6a7edf06f8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -116,6 +116,7 @@ examples/tutorials/*.svg
 doc/_build/*
 doc/tutorials/*
+doc/forhowto/*
 doc/sources/*
 *sg_execution_times.rst
diff --git a/doc/conf.py b/doc/conf.py
index b4ff6e97fe..1b46c964dd 100644
--- a/doc/conf.py
+++ b/doc/conf.py
@@ -38,6 +38,7 @@
     '../examples/tutorials/metrics/curated_sorting',
     '../examples/tutorials/metrics/clean_analyzer.zarr',
     '../examples/tutorials/widgets/waveforms_mearec',
+    '../examples/forhowto/cached'
 ]
@@ -123,8 +124,8 @@
     # This is the default but including here explicitly. Should build all docs and fail on gallery failures only.
     # other option would be abort_on_example_error, but this fails on first failure. So we decided against this.
     'only_warn_on_example_error': False,
-    'examples_dirs': ['../examples/tutorials'],
-    'gallery_dirs': ['tutorials' ], # path where to save gallery generated examples
+    'examples_dirs': ['../examples/tutorials', '../examples/forhowto'],
+    'gallery_dirs': ['tutorials', 'forhowto'], # path where to save gallery generated examples
     'subsection_order': ExplicitOrder([
         '../examples/tutorials/core',
         '../examples/tutorials/extractors',
@@ -132,7 +133,7 @@
         '../examples/tutorials/metrics',
         '../examples/tutorials/comparison',
         '../examples/tutorials/widgets',
-        '../examples/tutorials/forhowto',
+        '../examples/forhowto',
     ]),
     'within_subsection_order': FileNameSortKey,
     'ignore_pattern': '/generate_*',
diff --git a/doc/how_to/index.rst b/doc/how_to/index.rst
index db02da5cef..6f280f888f 100644
--- a/doc/how_to/index.rst
+++ b/doc/how_to/index.rst
@@ -10,10 +10,11 @@ Guides on how to solve specific, short problems in SpikeInterface. Learn how to.
     customize_a_plot
     combine_recordings
     process_by_channel_group
+    ../forhowto/plot_extract_lfps
     build_pipeline_with_dicts
     physical_units
     unsigned_to_signed
-    ../tutorials/forhowto/plot_1_working_with_tetrodes
+    ../forhowto/plot_working_with_tetrodes
    analyze_neuropixels
    handle_drift
    drift_with_lfp
diff --git a/examples/forhowto/README.rst b/examples/forhowto/README.rst
new file mode 100644
index 0000000000..09ef2f2b24
--- /dev/null
+++ b/examples/forhowto/README.rst
@@ -0,0 +1,9 @@
+For "How to"
+============
+
+These Python scripts are files which we would like sphinx-gallery to run,
+but which are linked from the "How to" guide.
+
+All the Python files in this folder are built by sphinx-gallery. The resulting
+documentation pages can be added to the "How to" section of the documentation by modifying
+the file ``doc/how_to/index.rst``.
diff --git a/examples/forhowto/plot_extract_lfps.py b/examples/forhowto/plot_extract_lfps.py
new file mode 100644
index 0000000000..53c4c256c2
--- /dev/null
+++ b/examples/forhowto/plot_extract_lfps.py
@@ -0,0 +1,333 @@
+"""
+Extract LFPs
+============
+
+Understanding filtering artifacts and chunking when extracting LFPs
+-------------------------------------------------------------------
+
+Local Field Potentials (LFPs) are low-frequency signals (<300 Hz) that reflect the summed activity of many neurons.
+Extracting LFPs from high-sampling-rate recordings requires bandpass filtering, but this can introduce artifacts
+when not done carefully, especially when data is processed in chunks (which is usually required because datasets
+cannot be loaded entirely into memory).
+
+Before we get started, let's introduce some important concepts:
+
+Chunk
+~~~~~
+
+A "chunk" is a piece of recording that gets processed in parallel by SpikeInterface.
+The default chunk duration for most operations is 1 second, but we'll see that this is not adequate for LFP
+processing.
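+
+For example, the chunk size is controlled through the job keyword arguments accepted by most processing
+functions. Here is a minimal sketch (the folder name and the values are illustrative, not recommendations):
+
+.. code-block:: python
+
+    # save a preprocessed recording using 4 parallel workers and 10-second chunks
+    recording.save(folder="preprocessed", n_jobs=4, chunk_duration="10s")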
+
+
+Margin
+~~~~~~
+
+When we apply a filter on chunked data, we extract additional "margins" of traces at the chunk borders.
+This is done to reduce border artifacts.
+
+
+This tutorial demonstrates:
+
+1. How to generate simulated LFP data
+2. Common pitfalls when filtering with low cutoff frequencies
+3. How chunking and margins affect filtering artifacts
+4. How to quantify and visualize chunking artifacts
+
+**Key takeaway**: For LFP extraction, use large chunks (30-60s) and large margins (several seconds) to minimize
+edge artifacts, even though this is less memory-efficient.
+"""
+
+##############################################################################
+# Import necessary modules
+
+import time
+import numpy as np
+import matplotlib.pyplot as plt
+from pathlib import Path
+import pandas as pd
+import seaborn as sns
+
+import spikeinterface as si
+import spikeinterface.extractors as se
+import spikeinterface.preprocessing as spre
+import spikeinterface.widgets as sw
+from spikeinterface.core import generate_ground_truth_recording
+
+##############################################################################
+# 1. Generate simulated recording with low-frequency signals
+# -----------------------------------------------------------
+#
+# Let's create a simulated recording and add some low-frequency sinusoids that mimic LFP activity.
+
+# Generate a 5-min ground truth recording with spikes
+# Use a higher sampling rate (30 kHz) to simulate raw neural data
+recording, sorting = generate_ground_truth_recording(
+    durations=[300.0],
+    sampling_frequency=30000.0,
+    num_channels=1,
+    num_units=4,
+    seed=2305,
+)
+
+print(recording)
+
+##############################################################################
+# Now let's add some low-frequency sinusoidal components to simulate LFP signals
+
+# Add low-frequency sinusoids with different frequencies and phases per channel
+rng = np.random.default_rng(42)
+num_channels = recording.get_num_channels()
+lfp_signals = np.zeros((recording.get_num_samples(), num_channels))
+time_vector = recording.get_times()
+
+for ch in range(num_channels):
+    # Add multiple frequency components (theta, alpha, beta ranges)
+    # Theta-like: 4-8 Hz
+    freq_theta = 4 + rng.random() * 4
+    phase_theta = rng.random() * 2 * np.pi
+    amp_theta = 50 + rng.random() * 50
+
+    # Alpha-like: 8-12 Hz
+    freq_alpha = 8 + rng.random() * 4
+    phase_alpha = rng.random() * 2 * np.pi
+    amp_alpha = 30 + rng.random() * 30
+
+    # Beta-like: 12-30 Hz
+    freq_beta = 12 + rng.random() * 18
+    phase_beta = rng.random() * 2 * np.pi
+    amp_beta = 20 + rng.random() * 20
+
+    lfp_signals[:, ch] = (
+        amp_theta * np.sin(2 * np.pi * freq_theta * time_vector + phase_theta)
+        + amp_alpha * np.sin(2 * np.pi * freq_alpha * time_vector + phase_alpha)
+        + amp_beta * np.sin(2 * np.pi * freq_beta * time_vector + phase_beta)
+    )
+
+# Create a recording with the added LFP signals
+recording_lfp = si.NumpyRecording(
+    traces_list=[lfp_signals],
+    sampling_frequency=recording.sampling_frequency,
+    channel_ids=recording.channel_ids,
+)
+recording_with_lfp = recording + recording_lfp
+
+
+##############################################################################
+# Let's visualize a short segment of the signal
+
+_ = sw.plot_traces(recording_with_lfp, time_range=[0, 3])
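+
+##############################################################################
+# As a quick sanity check (a minimal sketch reusing the ``lfp_signals`` array created above),
+# we can verify that the injected low-frequency components are present with a power spectral
+# density estimate:
+
+import scipy.signal
+
+freqs, psd = scipy.signal.welch(lfp_signals[:, 0], fs=recording.sampling_frequency, nperseg=2**16)
+fig_psd, ax_psd = plt.subplots()
+ax_psd.semilogy(freqs, psd)
+ax_psd.set_xlim(0, 50)  # the theta/alpha/beta peaks should appear below 50 Hz
+ax_psd.set_xlabel("Frequency (Hz)")
+_ = ax_psd.set_ylabel("PSD ($\\mu V^2$/Hz)")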
+
+##############################################################################
+# 2. Filtering with low cutoff frequencies: the problem
+# ------------------------------------------------------
+#
+# Now let's try to extract LFPs using a bandpass filter with a low highpass cutoff (1 Hz).
+# This will demonstrate a common issue.
+
+# Try to filter with 1 Hz highpass
+try:
+    recording_lfp_1hz = spre.bandpass_filter(recording_with_lfp, freq_min=1.0, freq_max=300.0)
+except Exception as e:
+    print(f"Error message:\n{str(e)}")
+
+##############################################################################
+# **Why does this fail?**
+#
+# SpikeInterface always raises this error when the highpass cutoff is below 100 Hz, to remind
+# users that extra care is needed.
+# Filters with very low cutoff frequencies have long impulse responses, which require large
+# margins to avoid edge artifacts between chunks.
+#
+# The filter length (and required margin) scales inversely with the highpass frequency. A 1 Hz highpass
+# filter requires a margin of several seconds, while a 300 Hz highpass (for spike extraction) only needs
+# a few milliseconds.
+#
+# **This error informs the user that extra care is needed when dealing with LFP signals!**
+
+
+##############################################################################
+# 3. Understanding chunking and margins
+# --------------------------------------
+#
+# SpikeInterface processes recordings in chunks to handle large datasets efficiently. Each chunk needs
+# a "margin" (extra samples at the edges) to avoid edge artifacts when filtering. Let's demonstrate
+# this by saving the filtered data with different chunking strategies.
+#
+# We can explicitly ignore the previous error, but let's make sure we understand what is happening.
+
+recording_filt = spre.bandpass_filter(
+    recording_with_lfp, freq_min=1.0, freq_max=300.0, ignore_low_freq_error=True
+)
+
+##############################################################################
+# When retrieving traces, extra samples will be read at the left and right edges.
+# By default, the filter function sets the margin to 5x the period associated with `freq_min`.
+# So for a 1 Hz cutoff frequency, the margin will be 5 seconds!
+
+margin_in_s = recording_filt.margin_samples / recording_filt.sampling_frequency
+print(f"Margin: {margin_in_s} s")
+
+##############################################################################
+# This effectively means that if we plot a 1-s snippet of traces, a total of 11 s will actually be read
+# and filtered. Hence the computational "overhead" is very large.
+# Note that the margin can be overridden with the `margin_ms` argument, but we do not recommend changing it.
+
+_ = sw.plot_traces(recording_filt, time_range=[20, 21])
+
+##############################################################################
+# A warning tells us that what we are doing is not optimized: in order to get the requested traces,
+# the margin "overhead" is very large.
+#
+# If we request or plot longer snippets, the warning is not displayed.
+
+_ = sw.plot_traces(recording_filt, time_range=[20, 80])
+
+##############################################################################
+# 4. Quantification and visualization of the artifacts
+# -----------------------------------------------------
+#
+# Let's extract the traces and visualize the differences between chunking strategies.
+# We'll focus on the chunk boundaries where artifacts appear.
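+
+##############################################################################
+# Before picking margins to test, it helps to see how the automatic "5x the period of
+# `freq_min`" rule described above scales with the highpass cutoff (a quick sketch):
+
+for freq_min in [1.0, 10.0, 100.0, 300.0]:
+    auto_margin_ms = 5 * 1000.0 / freq_min
+    print(f"freq_min = {freq_min:6.1f} Hz -> auto margin = {auto_margin_ms:8.1f} ms")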
+
+margins_ms = [100, 1000, 5000]
+chunk_durations = ["1s", "10s", "30s"]
+
+##############################################################################
+# The best we can do is to save the full recording in one chunk. This causes no artifacts or
+# chunking effects, but in practice it is not possible for most setups, due to the duration and
+# number of channels of the recordings.
+#
+# Since in this toy case we have a single-channel 5-min recording, we can use this as the "optimal" reference.
+
+recording_optimal = recording_filt.save(
+    folder="./cached/optimal",
+    chunk_duration="1000s",
+    progress_bar=False,
+)
+
+print(recording_optimal)
+
+##############################################################################
+# Now we can do the same with our various options:
+
+recordings_chunked = {}
+
+for margin_ms in margins_ms:
+    for chunk_duration in chunk_durations:
+        print(f"Margin ms: {margin_ms} - Chunk duration: {chunk_duration}")
+        t_start = time.perf_counter()
+        recording_chunk = spre.bandpass_filter(
+            recording_with_lfp,
+            freq_min=1.0,
+            freq_max=300.0,
+            margin_ms=margin_ms,
+            ignore_low_freq_error=True,
+        )
+        recording_chunk = recording_chunk.save(
+            folder=f"./cached/{margin_ms}_{chunk_duration}",
+            chunk_duration=chunk_duration,
+            verbose=False,
+            progress_bar=False,
+        )
+        t_stop = time.perf_counter()
+        result_dict = {"recording": recording_chunk, "time": t_stop - t_start}
+        recordings_chunked[(margin_ms, chunk_duration)] = result_dict
+
+##############################################################################
+# Let's visualize the error for the "10s" chunks and different margins, centered around 30s (which is a chunk edge):
+
+fig, ax = plt.subplots(figsize=(10, 5))
+trace_plotted = False
+start_time = 15  # seconds
+end_time = 45  # seconds
+start_frame = int(start_time * recording_optimal.sampling_frequency)
+end_frame = int(end_time * recording_optimal.sampling_frequency)
+timestamps = recording_optimal.get_times()[start_frame:end_frame]
+for recording_key, recording_dict in recordings_chunked.items():
+    recording_chunk = recording_dict["recording"]
+    margin, chunk = recording_key
+    # only plot "10s" chunks
+    if chunk != "10s":
+        continue
+    traces_opt = recording_optimal.get_traces(start_frame=start_frame, end_frame=end_frame)
+    if not trace_plotted:
+        ax.plot(timestamps, traces_opt, color="grey", label="traces", alpha=0.5)
+        trace_plotted = True
+    diff = recording_optimal - recording_chunk
+    traces_diff = diff.get_traces(start_frame=start_frame, end_frame=end_frame)
+    ax.plot(timestamps, traces_diff, label=f"Margin: {margin}")
+
+for chunk_edge in [20, 30, 40]:  # chunk boundaries at 10s intervals
+    ax.axvline(x=chunk_edge, color="red", linestyle="--", alpha=0.5)
+
+ax.set_xlabel("Time (s)")
+ax.set_ylabel("Voltage ($\\mu V$)")
+_ = ax.legend()
+
+##############################################################################
+# From the plot, we can see that there is a very small error when the margin is large (green),
+# a larger error with the intermediate margin (orange), and a large error when the margin is small (blue).
+# So we need large margins (compared to the chunk size) if we want accurate filtered traces.
+#
+# The artifacts do not depend on the chunk size, but for smaller chunk sizes these artifacts will happen
+# more often, as the quick check below shows. In addition, the margin "overhead" will make processing slower.
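+
+##############################################################################
+# A small back-of-the-envelope check (using the recording defined above): the number of internal
+# chunk boundaries, and hence of potential artifact sites, grows as chunks get smaller.
+
+total_duration = recording_with_lfp.get_total_duration()
+for chunk_s in [1, 10, 30]:
+    n_boundaries = int(total_duration // chunk_s) - 1
+    print(f"{chunk_s:3d}-s chunks -> {n_boundaries} internal chunk boundaries")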
+
+##############################################################################
+# Let's quantify these concepts by computing the overall absolute error with respect to the
+# optimal case and the processing time.
+
+traces_optimal = recording_optimal.get_traces()
+data = {"margin": [], "chunk": [], "error": [], "time": []}
+for recording_key, recording_dict in recordings_chunked.items():
+    recording_chunk = recording_dict["recording"]
+    proc_time = recording_dict["time"]
+    margin, chunk = recording_key
+    traces_chunk = recording_chunk.get_traces()
+    error = np.sum(np.abs(traces_optimal - traces_chunk))
+    data["margin"].append(margin)
+    data["chunk"].append(chunk)
+    data["error"].append(error)
+    data["time"].append(proc_time)
+
+df = pd.DataFrame(data=data)
+
+##############################################################################
+# Now let's visualize the error and processing time for different margin and chunk size combinations
+
+fig, axs = plt.subplots(ncols=2, figsize=(10, 5))
+sns.barplot(data=df, x="margin", y="error", hue="chunk", ax=axs[0])
+axs[0].set_yscale("log")
+sns.barplot(data=df, x="margin", y="time", hue="chunk", ax=axs[1])
+axs[0].set_title("Error VS margin x chunk size")
+axs[1].set_title("Processing time VS margin x chunk size")
+
+sns.despine(fig)
+
+##############################################################################
+# Summary
+# -------
+#
+# 1. **Low-frequency filters require special care**: Filters with low cutoff frequencies (< 10 Hz) have long
+#    impulse responses that require large margins to avoid edge artifacts.
+#
+# 2. **Chunking artifacts are real**: When processing data in chunks, insufficient margins lead to visible
+#    discontinuities and errors at chunk boundaries.
+#
+# 3. **The solution: large chunks and large margins**: For LFP extraction (1-300 Hz), use:
+#
+#    - Chunk size: 30-60 seconds
+#    - Margin size: 5 seconds (for 1 Hz highpass) (**use defaults!**)
+#    - This is less memory-efficient but more accurate
+#
+# 4. **Downsample after filtering**: After bandpass filtering, downsample to reduce data size (e.g., to 1-2.5 kHz
+#    for a 300 Hz max frequency), as sketched below.
+#
+# 5. **Trade-offs**: There's always a trade-off between computational efficiency (smaller chunks, less memory)
+#    and accuracy (larger chunks, fewer artifacts). For LFP analysis, accuracy should take priority.
+#
+# **When processing your own data:**
+#
+# - If you have memory constraints, use the largest chunk size your system can handle
+# - Always verify your filtering parameters on a small test segment first
+# - Consider the lowest frequency component you want to preserve when setting margins
+# - Save the processed LFP data to disk to avoid recomputing
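+
+##############################################################################
+# As a final sketch of point 4 above (assuming the ``recording_filt`` object defined earlier),
+# downsampling after filtering can be done with the ``resample`` preprocessing step.
+# The target rate here is illustrative:
+
+recording_lfp_ds = spre.resample(recording_filt, resample_rate=1250)
+print(recording_lfp_ds)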
diff --git a/examples/tutorials/forhowto/plot_1_working_with_tetrodes.py b/examples/forhowto/plot_working_with_tetrodes.py
similarity index 100%
rename from examples/tutorials/forhowto/plot_1_working_with_tetrodes.py
rename to examples/forhowto/plot_working_with_tetrodes.py
diff --git a/examples/tutorials/forhowto/README.rst b/examples/tutorials/forhowto/README.rst
deleted file mode 100644
index 34805dba84..0000000000
--- a/examples/tutorials/forhowto/README.rst
+++ /dev/null
@@ -1,4 +0,0 @@
-For how to
-----------
-
-These documents are files which we would like sphinx-gallery to run, but are linked by the how to guide.
diff --git a/pyproject.toml b/pyproject.toml
index 44601850e3..6cc950d4df 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -208,6 +208,7 @@ docs = [
     "hdbscan>=0.8.33", # For sorters spykingcircus2 + tridesclous
     "numba", # For many postprocessing functions
     "networkx",
+    "seaborn",
     "skops", # For automated curation
     "scikit-learn<1.8", # For automated curation
     "huggingface_hub", # For automated curation
diff --git a/src/spikeinterface/core/recording_tools.py b/src/spikeinterface/core/recording_tools.py
index fd95b11e6a..992a48c589 100644
--- a/src/spikeinterface/core/recording_tools.py
+++ b/src/spikeinterface/core/recording_tools.py
@@ -214,13 +214,7 @@ def write_binary_recording_file_handle(
 def _init_memory_worker(recording, arrays, shm_names, shapes, dtype):
     # create a local dict per worker
     worker_ctx = {}
-    if isinstance(recording, dict):
-        from spikeinterface.core import load
-
-        worker_ctx["recording"] = load(recording)
-    else:
-        worker_ctx["recording"] = recording
-
+    worker_ctx["recording"] = recording
     worker_ctx["dtype"] = np.dtype(dtype)

     if arrays is None:
diff --git a/src/spikeinterface/exporters/tests/test_export_to_ibl.py b/src/spikeinterface/exporters/tests/test_export_to_ibl.py
index 3b859634df..96038ced59 100644
--- a/src/spikeinterface/exporters/tests/test_export_to_ibl.py
+++ b/src/spikeinterface/exporters/tests/test_export_to_ibl.py
@@ -78,7 +78,7 @@ def test_export_lfp_to_ibl(sorting_analyzer_sparse_for_export, create_cache_fold
     sorting_analyzer = sorting_analyzer_sparse_for_export
     recording = sorting_analyzer.recording

-    recording_lfp = bandpass_filter(recording, freq_min=0.5, freq_max=300)
+    recording_lfp = bandpass_filter(recording, freq_min=0.5, freq_max=300, ignore_low_freq_error=True)
     recording_lfp = decimate(recording_lfp, 10)  # LFP, but no AP

     export_to_ibl_gui(
diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py
index 78542e1f37..732b310123 100644
--- a/src/spikeinterface/preprocessing/filter.py
+++ b/src/spikeinterface/preprocessing/filter.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import warnings

 import numpy as np
@@ -8,6 +9,10 @@
 from spikeinterface.core import get_chunk_with_margin

+HIGHPASS_ERROR_THRESHOLD_HZ = 100
+MARGIN_TO_CHUNK_PERCENT_WARNING = 0.2  # 20%
+
+
 _common_filter_docs = """**filter_kwargs : dict
     Certain keyword arguments for `scipy.signal` filters:
         filter_order : order
@@ -42,8 +47,9 @@ class FilterRecording(BasePreprocessor):
         If list. band (low, high) in Hz for "bandpass" filter type
     btype : "bandpass" | "highpass", default: "bandpass"
         Type of the filter
-    margin_ms : float, default: 5.0
-        Margin in ms on border to avoid border effect
+    margin_ms : float, default: None
+        Margin in ms on border to avoid border effect.
+        Must be provided by sub-class.
     coeff : array | None, default: None
         Filter coefficients in the filter_mode form.
     dtype : dtype or None, default: None
@@ -78,7 +84,7 @@ def __init__(
         filter_order=5,
         ftype="butter",
         filter_mode="sos",
-        margin_ms=5.0,
+        margin_ms=None,
         add_reflect_padding=False,
         coeff=None,
         dtype=None,
@@ -110,7 +116,9 @@ def __init__(
         if "offset_to_uV" in self.get_property_keys():
             self.set_channel_offsets(0)

+        assert margin_ms is not None, "margin_ms must be provided!"
         margin = int(margin_ms * fs / 1000.0)
+        self.margin_samples = margin
         for parent_segment in recording._recording_segments:
             self.add_recording_segment(
                 FilterRecordingSegment(
@@ -159,6 +167,12 @@ def __init__(
         self.dtype = dtype

     def get_traces(self, start_frame, end_frame, channel_indices):
+        if self.margin > MARGIN_TO_CHUNK_PERCENT_WARNING * (end_frame - start_frame):
+            warnings.warn(
+                f"The margin size ({self.margin} samples) is more than {int(MARGIN_TO_CHUNK_PERCENT_WARNING * 100)}% "
+                f"of the chunk size ({end_frame - start_frame} samples). This may lead to performance bottlenecks "
+                f"when chunking. Consider increasing the chunk size to minimize margin overhead."
+            )
         traces_chunk, left_margin, right_margin = get_chunk_with_margin(
             self.parent_recording_segment,
             start_frame,
@@ -217,10 +231,13 @@ class BandpassFilterRecording(FilterRecording):
         The highpass cutoff frequency in Hz
     freq_max : float
         The lowpass cutoff frequency in Hz
-    margin_ms : float
-        Margin in ms on border to avoid border effect
+    margin_ms : float | str, default: "auto"
+        Margin in ms on border to avoid border effect.
+        If "auto", the margin is computed as 5 times the filter highpass cutoff period.
     dtype : dtype or None
         The dtype of the returned traces. If None, the dtype of the parent recording is used
+    ignore_low_freq_error : bool, default: False
+        If True, does not raise an error if freq_min is below the 100 Hz safety threshold.
     {}

     Returns
     -------
@@ -229,13 +246,30 @@
     bandpass_recording : BandpassFilterRecording
         The bandpass-filtered recording extractor object
     """

-    def __init__(self, recording, freq_min=300.0, freq_max=6000.0, margin_ms=5.0, dtype=None, **filter_kwargs):
+    def __init__(
+        self,
+        recording,
+        freq_min=300.0,
+        freq_max=6000.0,
+        margin_ms="auto",
+        dtype=None,
+        ignore_low_freq_error=False,
+        **filter_kwargs,
+    ):
+        if margin_ms == "auto":
+            margin_ms = adjust_margin_ms_for_highpass(freq_min)
+        highpass_check(freq_min, margin_ms, ignore_low_freq_error=ignore_low_freq_error)
         FilterRecording.__init__(
             self, recording, band=[freq_min, freq_max], margin_ms=margin_ms, dtype=dtype, **filter_kwargs
         )
         dtype = fix_dtype(recording, dtype)
         self._kwargs = dict(
-            recording=recording, freq_min=freq_min, freq_max=freq_max, margin_ms=margin_ms, dtype=dtype.str
+            recording=recording,
+            freq_min=freq_min,
+            freq_max=freq_max,
+            margin_ms=margin_ms,
+            dtype=dtype.str,
+            ignore_low_freq_error=ignore_low_freq_error,
         )
         self._kwargs.update(filter_kwargs)
@@ -250,10 +284,13 @@ class HighpassFilterRecording(FilterRecording):
         The recording extractor to be re-referenced
     freq_min : float
         The highpass cutoff frequency in Hz
-    margin_ms : float
-        Margin in ms on border to avoid border effect
+    margin_ms : float | str, default: "auto"
+        Margin in ms on border to avoid border effect.
+        If "auto", the margin is computed as 5 times the filter highpass cutoff period.
     dtype : dtype or None
         The dtype of the returned traces. If None, the dtype of the parent recording is used
+    ignore_low_freq_error : bool, default: False
+        If True, does not raise an error if freq_min is below the 100 Hz safety threshold.
     {}

     Returns
     -------
     highpass_recording : HighpassFilterRecording
         The highpass-filtered recording extractor object
     """

@@ -262,7 +299,12 @@
-    def __init__(self, recording, freq_min=300.0, margin_ms=5.0, dtype=None, **filter_kwargs):
+    def __init__(
+        self, recording, freq_min=300.0, margin_ms="auto", dtype=None, ignore_low_freq_error=False, **filter_kwargs
+    ):
+        if margin_ms == "auto":
+            margin_ms = adjust_margin_ms_for_highpass(freq_min)
+        highpass_check(freq_min, margin_ms, ignore_low_freq_error=ignore_low_freq_error)
         FilterRecording.__init__(
             self, recording, band=freq_min, margin_ms=margin_ms, dtype=dtype, btype="highpass", **filter_kwargs
         )
@@ -271,7 +313,7 @@ def __init__(self, recording, freq_min=300.0, margin_ms=5.0, dtype=None, **filte
         self._kwargs.update(filter_kwargs)


-class NotchFilterRecording(BasePreprocessor):
+class NotchFilterRecording(FilterRecording):
     """
     Parameters
     ----------
@@ -283,7 +325,7 @@ class NotchFilterRecording(BasePreprocessor):
         The quality factor of the notch filter
     dtype : None | dtype, default: None
         dtype of recording. If None, will take from `recording`
-    margin_ms : float, default: 5.0
+    margin_ms : float | str, default: "auto"
         Margin in ms on border to avoid border effect

     Returns
     -------
     notch_recording : NotchFilterRecording
         The notch-filtered recording extractor object
     """

-    def __init__(self, recording, freq=3000, q=30, margin_ms=5.0, dtype=None):
-        # coeef is 'ba' type
-        fn = 0.5 * float(recording.get_sampling_frequency())
+    def __init__(self, recording, freq=3000, q=30, margin_ms="auto", dtype=None, **filter_kwargs):
         import scipy.signal

+        if margin_ms == "auto":
+            margin_ms = adjust_margin_ms_for_notch(q, freq)
+
+        fn = 0.5 * float(recording.get_sampling_frequency())
         coeff = scipy.signal.iirnotch(freq / fn, q)

-        if dtype is None:
-            dtype = recording.get_dtype()
-        dtype = np.dtype(dtype)
+        dtype = fix_dtype(recording, dtype)

         # if uint --> unsupported
         if dtype.kind == "u":
@@ -310,15 +352,12 @@
                 "to specify a signed type (e.g. 'int16', 'float32')"
             )

-        BasePreprocessor.__init__(self, recording, dtype=dtype)
+        FilterRecording.__init__(
+            self, recording, coeff=coeff, filter_mode="ba", margin_ms=margin_ms, dtype=dtype, **filter_kwargs
+        )
         self.annotate(is_filtered=True)
-
-        sf = recording.get_sampling_frequency()
-        margin = int(margin_ms * sf / 1000.0)
-        for parent_segment in recording._recording_segments:
-            self.add_recording_segment(FilterRecordingSegment(parent_segment, coeff, "ba", margin, dtype))

         self._kwargs = dict(recording=recording, freq=freq, q=q, margin_ms=margin_ms, dtype=dtype.str)
+        self._kwargs.update(filter_kwargs)


 # functions for API
@@ -398,6 +437,38 @@ def causal_filter(
 highpass_filter.__doc__ = highpass_filter.__doc__.format(_common_filter_docs)


+def adjust_margin_ms_for_highpass(freq_min, multiplier=5):
+    margin_ms = multiplier * (1000.0 / freq_min)
+    return margin_ms
+
+
+def adjust_margin_ms_for_notch(q, f0, multiplier=5):
+    margin_ms = (multiplier / np.pi) * (q / f0) * 1000.0
+    return margin_ms
+
+
+def highpass_check(freq_min, margin_ms, ignore_low_freq_error=False):
+    if freq_min < HIGHPASS_ERROR_THRESHOLD_HZ:
+        if not ignore_low_freq_error:
+            raise ValueError(
+                f"The freq_min ({freq_min} Hz) is too low and may cause artifacts during chunk processing. "
+                f"You can set 'ignore_low_freq_error=True' to bypass this error, but make sure you understand the implications. "
" + f"It is recommended to use large chunks when processing/saving your filtered recording to minimize IO overhead." + f"Refer to this documentation on LFP filtering and chunking artifacts for more details: " + f"https://spikeinterface.readthedocs.io/en/latest/how-to/extract_lfps.html. " + ) + if margin_ms == "auto": + margin_ms = adjust_margin_ms_for_highpass(freq_min) + else: + auto_margin_ms = adjust_margin_ms_for_highpass(freq_min) + if margin_ms < auto_margin_ms: + warnings.warn( + f"The provided margin_ms ({margin_ms} ms) is smaller than the recommended margin for the given freq_min ({freq_min} Hz). " + f"This may lead to artifacts at the edges of chunks during processing. " + f"Consider increasing the margin_ms to at least {auto_margin_ms} ms." + ) + + def fix_dtype(recording, dtype): """ Fix recording dtype for preprocessing, by always returning a numpy.dtype.