Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1137,12 +1137,16 @@ def _load_extractor_from_dict(dic) -> BaseExtractor:
extractor_class = _get_class_from_string(class_name)

assert extractor_class is not None and class_name is not None, "Could not load spikeinterface class"
if not _check_same_version(class_name, dic["version"]):
is_old_version = not _check_same_version(class_name, dic["version"])
if is_old_version:
warnings.warn(
f"Versions are not the same. This might lead to compatibility errors. "
f"Using {class_name.split('.')[0]}=={dic['version']} is recommended"
)

if hasattr(extractor_class, "_handle_backward_compatibility"):
new_kwargs = extractor_class._handle_backward_compatibility(new_kwargs, dic)

# Initialize the extractor
extractor = extractor_class(**new_kwargs)

Expand Down
70 changes: 57 additions & 13 deletions src/spikeinterface/preprocessing/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import numpy as np

from spikeinterface.core.core_tools import define_function_handling_dict_from_class
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment
from spikeinterface.core import get_chunk_with_margin, ensure_chunk_size, get_global_job_kwargs

from spikeinterface.core import get_chunk_with_margin
from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment

HIGHPASS_ERROR_THRESHOLD_HZ = 100
MARGIN_TO_CHUNK_PERCENT_WARNING = 0.2 # 20%
Expand Down Expand Up @@ -117,6 +117,14 @@ def __init__(

assert margin_ms is not None, "margin_ms must be provided!"
margin = int(margin_ms * fs / 1000.0)

global_job_kwargs_chunk_size = ensure_chunk_size(recording, **get_global_job_kwargs())
if margin > MARGIN_TO_CHUNK_PERCENT_WARNING * global_job_kwargs_chunk_size:
warnings.warn(
f"The margin size ({margin} samples) is more than {int(MARGIN_TO_CHUNK_PERCENT_WARNING * 100)}% "
f"of the global chunk size {global_job_kwargs_chunk_size} samples. This may lead to performance bottlenecks when "
f"chunking. Consider increasing the chunk_size or chunk_duration to minimize margin overhead."
)
self.margin_samples = margin
for parent_segment in recording._recording_segments:
self.add_recording_segment(
Expand Down Expand Up @@ -166,12 +174,6 @@ def __init__(
self.dtype = dtype

def get_traces(self, start_frame, end_frame, channel_indices):
if self.margin > MARGIN_TO_CHUNK_PERCENT_WARNING * (end_frame - start_frame):
warnings.warn(
f"The margin size ({self.margin} samples) is more than {int(MARGIN_TO_CHUNK_PERCENT_WARNING * 100)}% "
f"of the chunk size {(end_frame - start_frame)} samples. This may lead to performance bottlenecks when "
f"chunking. Consider increasing the chunk size to minimize margin overhead."
)
traces_chunk, left_margin, right_margin = get_chunk_with_margin(
self.parent_recording_segment,
start_frame,
Expand Down Expand Up @@ -253,11 +255,17 @@ def __init__(
margin_ms="auto",
dtype=None,
ignore_low_freq_error=False,
_skip_margin_warning_for_old_version=False,
**filter_kwargs,
):
if margin_ms == "auto":
margin_ms = adjust_margin_ms_for_highpass(freq_min)
highpass_check(freq_min, margin_ms, ignore_low_freq_error=ignore_low_freq_error)
highpass_check(
freq_min,
margin_ms,
ignore_low_freq_error=ignore_low_freq_error,
skip_warning=_skip_margin_warning_for_old_version,
)
FilterRecording.__init__(
self, recording, band=[freq_min, freq_max], margin_ms=margin_ms, dtype=dtype, **filter_kwargs
)
Expand All @@ -272,6 +280,18 @@ def __init__(
)
self._kwargs.update(filter_kwargs)

@classmethod
def _handle_backward_compatibility(cls, old_kwargs, full_dict):
new_kwargs = old_kwargs.copy()
is_lfp_case = old_kwargs["freq_min"] < HIGHPASS_ERROR_THRESHOLD_HZ
if "ignore_low_freq_error" not in new_kwargs:
new_kwargs["ignore_low_freq_error"] = True
if is_lfp_case:
new_kwargs["_skip_margin_warning_for_old_version"] = False
else:
new_kwargs["_skip_margin_warning_for_old_version"] = True
return new_kwargs


class HighpassFilterRecording(FilterRecording):
"""
Expand Down Expand Up @@ -299,18 +319,42 @@ class HighpassFilterRecording(FilterRecording):
"""

def __init__(
self, recording, freq_min=300.0, margin_ms="auto", dtype=None, ignore_low_freq_error=False, **filter_kwargs
self,
recording,
freq_min=300.0,
margin_ms="auto",
dtype=None,
ignore_low_freq_error=False,
_skip_margin_warning_for_old_version=False,
**filter_kwargs,
):
if margin_ms == "auto":
margin_ms = adjust_margin_ms_for_highpass(freq_min)
highpass_check(freq_min, margin_ms, ignore_low_freq_error=ignore_low_freq_error)
highpass_check(
freq_min,
margin_ms,
ignore_low_freq_error=ignore_low_freq_error,
skip_warning=_skip_margin_warning_for_old_version,
)
FilterRecording.__init__(
self, recording, band=freq_min, margin_ms=margin_ms, dtype=dtype, btype="highpass", **filter_kwargs
)
dtype = fix_dtype(recording, dtype)
self._kwargs = dict(recording=recording, freq_min=freq_min, margin_ms=margin_ms, dtype=dtype.str)
self._kwargs.update(filter_kwargs)

@classmethod
def _handle_backward_compatibility(cls, old_kwargs, full_dict):
new_kwargs = old_kwargs.copy()
is_lfp_case = old_kwargs["freq_min"] < HIGHPASS_ERROR_THRESHOLD_HZ
if "ignore_low_freq_error" not in new_kwargs:
new_kwargs["ignore_low_freq_error"] = True
if is_lfp_case:
new_kwargs["_skip_margin_warning_for_old_version"] = False
else:
new_kwargs["_skip_margin_warning_for_old_version"] = True
return new_kwargs


class NotchFilterRecording(FilterRecording):
"""
Expand Down Expand Up @@ -446,7 +490,7 @@ def adjust_margin_ms_for_notch(q, f0, multiplier=5):
return margin_ms


def highpass_check(freq_min, margin_ms, ignore_low_freq_error=False):
def highpass_check(freq_min, margin_ms, ignore_low_freq_error=False, skip_warning=False):
if freq_min < HIGHPASS_ERROR_THRESHOLD_HZ:
if not ignore_low_freq_error:
raise ValueError(
Expand All @@ -460,7 +504,7 @@ def highpass_check(freq_min, margin_ms, ignore_low_freq_error=False):
margin_ms = adjust_margin_ms_for_highpass(freq_min)
else:
auto_margin_ms = adjust_margin_ms_for_highpass(freq_min)
if margin_ms < auto_margin_ms:
if margin_ms < auto_margin_ms and not skip_warning:
warnings.warn(
f"The provided margin_ms ({margin_ms} ms) is smaller than the recommended margin for the given freq_min ({freq_min} Hz). "
f"This may lead to artifacts at the edges of chunks during processing. "
Expand Down