Skip to content
32 changes: 25 additions & 7 deletions src/spikeinterface/sorters/internal/lupin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,10 @@ class LupinSorter(ComponentsBasedSorter):
"clustering_ms_after": 1.3,
"whitening_radius_um": 100.0,
"detection_radius_um": 50.0,
"features_radius_um": 75.0,
"features_radius_um": 120.0,
"split_radius_um" : 60.0,
"template_radius_um": 100.0,
"merge_similarity_lag_ms": 0.5,
"freq_min": 150.0,
"freq_max": 7000.0,
"cache_preprocessing_mode": "auto",
Expand All @@ -55,9 +57,10 @@ class LupinSorter(ComponentsBasedSorter):
"clustering_recursive_depth": 3,
"ms_before": 1.0,
"ms_after": 2.5,
"template_sparsify_threshold": 1.5,
"template_sparsify_threshold": 1.,
"template_min_snr_ptp": 4.0,
"template_max_jitter_ms": 0.2,
"template_matching_engine": "circus-omp",
"min_firing_rate": 0.1,
"gather_mode": "memory",
"job_kwargs": {},
Expand All @@ -74,6 +77,11 @@ class LupinSorter(ComponentsBasedSorter):
"clustering_ms_before": "Milliseconds before the spike peak for clustering",
"clustering_ms_after": "Milliseconds after the spike peak for clustering",
"radius_um": "Radius for sparsity",
"whitening_radius_um": "Radius for whitening",
"detection_radius_um": "Radius for peak detection",
"features_radius_um": "Radius for sparsity in SVD features",
"split_radius_um" : "Radius for the local split clustering",
"template_radius_um": "Radius for the sparsity of template before template matching",
"freq_min": "Low frequency",
"freq_max": "High frequency",
"peak_sign": "Sign of peaks neg/pos/both",
Expand All @@ -99,7 +107,7 @@ class LupinSorter(ComponentsBasedSorter):

@classmethod
def get_sorter_version(cls):
return "2025.12"
return "2026.01"

@classmethod
def _run_from_folder(cls, sorter_output_folder, params, verbose):
Expand Down Expand Up @@ -201,6 +209,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
dtype="float32",
mode="local",
radius_um=params["whitening_radius_um"],
seed=seed,
)

if params["apply_motion_correction"]:
Expand All @@ -219,18 +228,20 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

# Cache in mem or folder
cache_folder = sorter_output_folder / "cache_preprocessing"
recording_pre_cache = recording
recording, cache_info = cache_preprocessing(
recording,
mode=params["cache_preprocessing_mode"],
folder=cache_folder,
job_kwargs=job_kwargs,
)

noise_levels = get_noise_levels(recording, return_in_uV=False)

else:
recording = recording_raw.astype("float32")
noise_levels = get_noise_levels(recording, return_in_uV=False)
cache_info = None

noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=dict(seed=seed))

# detection
ms_before = params["ms_before"]
Expand Down Expand Up @@ -265,20 +276,26 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if verbose:
print(f"select_peaks(): {len(peaks)} peaks kept for clustering")

num_shifts_merging = int(sampling_frequency * params["merge_similarity_lag_ms"] / 1000.)

# Clustering
clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params)
clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"]
clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"]
clustering_kwargs["peaks_svd"]["radius_um"] = params["features_radius_um"]
clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components_per_channel"]
clustering_kwargs["split"]["split_radius_um"] = params["split_radius_um"]
clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"]
clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"]
clustering_kwargs["clean_templates"]["sparsify_threshold"] = params["template_sparsify_threshold"]
clustering_kwargs["clean_templates"]["min_snr"] = params["template_min_snr_ptp"]
clustering_kwargs["clean_templates"]["max_jitter_ms"] = params["template_max_jitter_ms"]
clustering_kwargs["merge_from_templates"]["use_lags"] = True
clustering_kwargs["merge_from_templates"]["num_shifts"] = num_shifts_merging
clustering_kwargs["noise_levels"] = noise_levels
clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"]
clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size
clustering_kwargs["seed"] = seed

if params["debug"]:
clustering_kwargs["debug_folder"] = sorter_output_folder
Expand Down Expand Up @@ -353,7 +370,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
spikes = find_spikes_from_templates(
recording,
templates,
method="wobble",
method=params["template_matching_engine"],
method_kwargs={},
pipeline_kwargs=pipeline_kwargs,
job_kwargs=job_kwargs,
Expand All @@ -377,7 +394,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
templates,
amplitude_scalings=spikes["amplitude"],
noise_levels=noise_levels,
similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1},
similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": params["merge_similarity_lag_ms"]},
sparsity_overlap=0.5,
censor_ms=3.0,
max_distance_um=50,
Expand All @@ -396,6 +413,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
np.save(sorter_output_folder / "spikes.npy", spikes)
templates.to_zarr(sorter_output_folder / "templates.zarr")
if analyzer_final is not None:
analyzer_final._recording = recording_pre_cache
analyzer_final.save_as(format="binary_folder", folder=sorter_output_folder / "analyzer")

sorting = sorting.save(folder=sorter_output_folder / "sorting")
Expand Down
24 changes: 19 additions & 5 deletions src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
"clustering_ms_before": 0.5,
"clustering_ms_after": 1.5,
"detection_radius_um": 150.0,
"features_radius_um": 75.0,
"features_radius_um": 120.0,
"split_radius_um" : 60.0,
"template_radius_um": 100.0,
"merge_similarity_lag_ms": 0.5,
"freq_min": 150.0,
"freq_max": 6000.0,
"cache_preprocessing_mode": "auto",
Expand Down Expand Up @@ -69,6 +71,10 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
"clustering_ms_before": "Milliseconds before the spike peak for clustering",
"clustering_ms_after": "Milliseconds after the spike peak for clustering",
"radius_um": "Radius for sparsity",
"detection_radius_um": "Radius for peak detection",
"features_radius_um": "Radius for sparsity in SVD features",
"split_radius_um" : "Radius for the local split clustering",
"template_radius_um": "Radius for the sparsity of template before template matching",
"freq_min": "Low frequency for bandpass filter",
"freq_max": "High frequency for bandpass filter",
"peak_sign": "Sign of peaks neg/pos/both",
Expand All @@ -94,7 +100,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):

@classmethod
def get_sorter_version(cls):
return "2025.12"
return "2026.01"

@classmethod
def _run_from_folder(cls, sorter_output_folder, params, verbose):
Expand Down Expand Up @@ -182,6 +188,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

# Cache in mem or folder
cache_folder = sorter_output_folder / "cache_preprocessing"
recording_pre_cache = recording
recording, cache_info = cache_preprocessing(
recording,
mode=params["cache_preprocessing_mode"],
Expand All @@ -191,8 +198,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

noise_levels = np.ones(num_chans, dtype="float32")
else:
recording_pre_cache = recording
recording = recording_raw.astype("float32")
noise_levels = get_noise_levels(recording, return_in_uV=False)
noise_levels = get_noise_levels(recording, return_in_uV=False, random_slices_kwargs=dict(seed=seed))
cache_info = None

# detection
Expand Down Expand Up @@ -225,6 +233,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
# )

# Clustering
num_shifts_merging = int(sampling_frequency * params["merge_similarity_lag_ms"] / 1000.)

clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params)
clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"]
clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"]
Expand All @@ -235,9 +245,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
clustering_kwargs["clean_templates"]["sparsify_threshold"] = params["template_sparsify_threshold"]
clustering_kwargs["clean_templates"]["min_snr"] = params["template_min_snr_ptp"]
clustering_kwargs["clean_templates"]["max_jitter_ms"] = params["template_max_jitter_ms"]
clustering_kwargs["merge_from_templates"]["num_shifts"] = num_shifts_merging
clustering_kwargs["noise_levels"] = noise_levels
clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"]
clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size
clustering_kwargs["seed"] = seed

if params["debug"]:
clustering_kwargs["debug_folder"] = sorter_output_folder
Expand Down Expand Up @@ -331,7 +343,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
final_spikes["segment_index"] = spikes["segment_index"]
sorting = NumpySorting(final_spikes, sampling_frequency, templates.unit_ids)

auto_merge = True
# auto_merge = True
auto_merge = False
analyzer_final = None
if auto_merge:
from spikeinterface.sorters.internal.spyking_circus2 import final_cleaning_circus
Expand All @@ -342,7 +355,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
templates,
amplitude_scalings=spikes["amplitude"],
noise_levels=noise_levels,
similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1},
similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": params["merge_similarity_lag_ms"]},
sparsity_overlap=0.5,
censor_ms=3.0,
max_distance_um=50,
Expand All @@ -362,6 +375,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
np.save(sorter_output_folder / "spikes.npy", spikes)
templates.to_zarr(sorter_output_folder / "templates.zarr")
if analyzer_final is not None:
analyzer_final._recording = recording_pre_cache
analyzer_final.save_as(format="binary_folder", folder=sorter_output_folder / "analyzer")

sorting = sorting.save(folder=sorter_output_folder / "sorting")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class IterativeISOSPLITClustering:
"similarity_metric": "l1",
"num_shifts": 3,
"similarity_thresh": 0.8,
"use_lags": True,
},
"merge_from_features": None,
# "merge_from_features": {"merge_radius_um": 60.0},
Expand Down Expand Up @@ -106,13 +107,16 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):

ms_before = params["peaks_svd"]["ms_before"]
ms_after = params["peaks_svd"]["ms_after"]
nbefore = int(ms_before * recording.sampling_frequency / 1000.)
nafter = int(ms_after * recording.sampling_frequency / 1000.)

# radius_um = params["waveforms"]["radius_um"]
verbose = params["verbose"]

debug_folder = params["debug_folder"]

params_peak_svd = params["peaks_svd"].copy()

params_peak_svd["seed"] = params["seed"]
motion = params_peak_svd["motion"]
motion_aware = motion is not None

Expand Down Expand Up @@ -285,13 +289,17 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
post_merge_label1 = post_split_label.copy()

if params["merge_from_templates"] is not None:
params_merge_from_templates = params["merge_from_templates"].copy()
num_shifts = params_merge_from_templates["num_shifts"]
num_shifts = min((num_shifts, nbefore, nafter))
params_merge_from_templates["num_shifts"] = num_shifts
post_merge_label2, templates_array, template_sparse_mask, unit_ids = merge_peak_labels_from_templates(
peaks,
post_merge_label1,
unit_ids,
templates_array,
template_sparse_mask,
**params["merge_from_templates"],
**params_merge_from_templates,
)
else:
post_merge_label2 = post_merge_label1.copy()
Expand Down