diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index d0a3deac81..057ca0486d 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -1137,12 +1137,16 @@ def _load_extractor_from_dict(dic) -> BaseExtractor: extractor_class = _get_class_from_string(class_name) assert extractor_class is not None and class_name is not None, "Could not load spikeinterface class" - if not _check_same_version(class_name, dic["version"]): + is_old_version = not _check_same_version(class_name, dic["version"]) + if is_old_version: warnings.warn( f"Versions are not the same. This might lead to compatibility errors. " f"Using {class_name.split('.')[0]}=={dic['version']} is recommended" ) + if hasattr(extractor_class, "_handle_backward_compatibility"): + new_kwargs = extractor_class._handle_backward_compatibility(new_kwargs, dic) + # Initialize the extractor extractor = extractor_class(**new_kwargs) diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index d29ae756ca..ef369662e7 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -4,9 +4,9 @@ import numpy as np from spikeinterface.core.core_tools import define_function_handling_dict_from_class -from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment +from spikeinterface.core import get_chunk_with_margin, ensure_chunk_size, get_global_job_kwargs -from spikeinterface.core import get_chunk_with_margin +from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment HIGHPASS_ERROR_THRESHOLD_HZ = 100 MARGIN_TO_CHUNK_PERCENT_WARNING = 0.2 # 20% @@ -117,6 +117,14 @@ def __init__( assert margin_ms is not None, "margin_ms must be provided!" margin = int(margin_ms * fs / 1000.0) + + global_job_kwargs_chunk_size = ensure_chunk_size(recording, **get_global_job_kwargs()) + if margin > MARGIN_TO_CHUNK_PERCENT_WARNING * global_job_kwargs_chunk_size: + warnings.warn( + f"The margin size ({margin} samples) is more than {int(MARGIN_TO_CHUNK_PERCENT_WARNING * 100)}% " + f"of the global chunk size {global_job_kwargs_chunk_size} samples. This may lead to performance bottlenecks when " + f"chunking. Consider increasing the chunk_size or chunk_duration to minimize margin overhead." + ) self.margin_samples = margin for parent_segment in recording._recording_segments: self.add_recording_segment( @@ -166,12 +174,6 @@ def __init__( self.dtype = dtype def get_traces(self, start_frame, end_frame, channel_indices): - if self.margin > MARGIN_TO_CHUNK_PERCENT_WARNING * (end_frame - start_frame): - warnings.warn( - f"The margin size ({self.margin} samples) is more than {int(MARGIN_TO_CHUNK_PERCENT_WARNING * 100)}% " - f"of the chunk size {(end_frame - start_frame)} samples. This may lead to performance bottlenecks when " - f"chunking. Consider increasing the chunk size to minimize margin overhead." - ) traces_chunk, left_margin, right_margin = get_chunk_with_margin( self.parent_recording_segment, start_frame, @@ -253,11 +255,17 @@ def __init__( margin_ms="auto", dtype=None, ignore_low_freq_error=False, + _skip_margin_warning_for_old_version=False, **filter_kwargs, ): if margin_ms == "auto": margin_ms = adjust_margin_ms_for_highpass(freq_min) - highpass_check(freq_min, margin_ms, ignore_low_freq_error=ignore_low_freq_error) + highpass_check( + freq_min, + margin_ms, + ignore_low_freq_error=ignore_low_freq_error, + skip_warning=_skip_margin_warning_for_old_version, + ) FilterRecording.__init__( self, recording, band=[freq_min, freq_max], margin_ms=margin_ms, dtype=dtype, **filter_kwargs ) @@ -272,6 +280,18 @@ def __init__( ) self._kwargs.update(filter_kwargs) + @classmethod + def _handle_backward_compatibility(cls, old_kwargs, full_dict): + new_kwargs = old_kwargs.copy() + is_lfp_case = old_kwargs["freq_min"] < HIGHPASS_ERROR_THRESHOLD_HZ + if "ignore_low_freq_error" not in new_kwargs: + new_kwargs["ignore_low_freq_error"] = True + if is_lfp_case: + new_kwargs["_skip_margin_warning_for_old_version"] = False + else: + new_kwargs["_skip_margin_warning_for_old_version"] = True + return new_kwargs + class HighpassFilterRecording(FilterRecording): """ @@ -299,11 +319,23 @@ class HighpassFilterRecording(FilterRecording): """ def __init__( - self, recording, freq_min=300.0, margin_ms="auto", dtype=None, ignore_low_freq_error=False, **filter_kwargs + self, + recording, + freq_min=300.0, + margin_ms="auto", + dtype=None, + ignore_low_freq_error=False, + _skip_margin_warning_for_old_version=False, + **filter_kwargs, ): if margin_ms == "auto": margin_ms = adjust_margin_ms_for_highpass(freq_min) - highpass_check(freq_min, margin_ms, ignore_low_freq_error=ignore_low_freq_error) + highpass_check( + freq_min, + margin_ms, + ignore_low_freq_error=ignore_low_freq_error, + skip_warning=_skip_margin_warning_for_old_version, + ) FilterRecording.__init__( self, recording, band=freq_min, margin_ms=margin_ms, dtype=dtype, btype="highpass", **filter_kwargs ) @@ -311,6 +343,18 @@ def __init__( self._kwargs = dict(recording=recording, freq_min=freq_min, margin_ms=margin_ms, dtype=dtype.str) self._kwargs.update(filter_kwargs) + @classmethod + def _handle_backward_compatibility(cls, old_kwargs, full_dict): + new_kwargs = old_kwargs.copy() + is_lfp_case = old_kwargs["freq_min"] < HIGHPASS_ERROR_THRESHOLD_HZ + if "ignore_low_freq_error" not in new_kwargs: + new_kwargs["ignore_low_freq_error"] = True + if is_lfp_case: + new_kwargs["_skip_margin_warning_for_old_version"] = False + else: + new_kwargs["_skip_margin_warning_for_old_version"] = True + return new_kwargs + class NotchFilterRecording(FilterRecording): """ @@ -446,7 +490,7 @@ def adjust_margin_ms_for_notch(q, f0, multiplier=5): return margin_ms -def highpass_check(freq_min, margin_ms, ignore_low_freq_error=False): +def highpass_check(freq_min, margin_ms, ignore_low_freq_error=False, skip_warning=False): if freq_min < HIGHPASS_ERROR_THRESHOLD_HZ: if not ignore_low_freq_error: raise ValueError( @@ -460,7 +504,7 @@ def highpass_check(freq_min, margin_ms, ignore_low_freq_error=False): margin_ms = adjust_margin_ms_for_highpass(freq_min) else: auto_margin_ms = adjust_margin_ms_for_highpass(freq_min) - if margin_ms < auto_margin_ms: + if margin_ms < auto_margin_ms and not skip_warning: warnings.warn( f"The provided margin_ms ({margin_ms} ms) is smaller than the recommended margin for the given freq_min ({freq_min} Hz). " f"This may lead to artifacts at the edges of chunks during processing. "