diff --git a/src/spikeinterface/benchmark/benchmark_base.py b/src/spikeinterface/benchmark/benchmark_base.py index 95674034f1..e8a619305f 100644 --- a/src/spikeinterface/benchmark/benchmark_base.py +++ b/src/spikeinterface/benchmark/benchmark_base.py @@ -457,27 +457,6 @@ def compute_results(self, case_keys=None, verbose=False, **result_params): benchmark.compute_result(**result_params) benchmark.save_result(self.folder / "results" / self.key_to_str(key)) - def create_sorting_analyzer_gt(self, case_keys=None, return_in_uV=True, random_params={}, **job_kwargs): - print("###### Study.create_sorting_analyzer_gt() is not used anymore!!!!!!") - # if case_keys is None: - # case_keys = self.cases.keys() - - # base_folder = self.folder / "sorting_analyzer" - # base_folder.mkdir(exist_ok=True) - - # dataset_keys = [self.cases[key]["dataset"] for key in case_keys] - # dataset_keys = set(dataset_keys) - # for dataset_key in dataset_keys: - # # the waveforms depend on the dataset key - # folder = base_folder / self.key_to_str(dataset_key) - # recording, gt_sorting = self.datasets[dataset_key] - # sorting_analyzer = create_sorting_analyzer( - # gt_sorting, recording, format="binary_folder", folder=folder, return_in_uV=return_in_uV - # ) - # sorting_analyzer.compute("random_spikes", **random_params) - # sorting_analyzer.compute("templates", **job_kwargs) - # sorting_analyzer.compute("noise_levels") - def get_sorting_analyzer(self, case_key=None, dataset_key=None): if case_key is not None: dataset_key = self.cases[case_key]["dataset"] diff --git a/src/spikeinterface/benchmark/benchmark_clustering.py b/src/spikeinterface/benchmark/benchmark_clustering.py index 6cf6cbe7a3..ba9fa53a51 100644 --- a/src/spikeinterface/benchmark/benchmark_clustering.py +++ b/src/spikeinterface/benchmark/benchmark_clustering.py @@ -222,11 +222,21 @@ def plot_performances_ordered(self, *args, **kwargs): return plot_performances_ordered(self, *args, **kwargs) + def plot_some_over_merged(self, *args, **kwargs): + from .benchmark_plot_tools import plot_some_over_merged + + return plot_some_over_merged(self, *args, **kwargs) + + def plot_some_over_splited(self, *args, **kwargs): + from .benchmark_plot_tools import plot_some_over_splited + + return plot_some_over_splited(self, *args, **kwargs) + def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)): if case_keys is None: case_keys = list(self.cases.keys()) - import pylab as plt + import matplotlib.pyplot as plt fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) @@ -263,7 +273,7 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5 if case_keys is None: case_keys = list(self.cases.keys()) - import pylab as plt + import matplotlib.pyplot as plt if axes is None: fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) @@ -322,7 +332,7 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs if case_keys is None: case_keys = list(self.cases.keys()) - import pylab as plt + import matplotlib.pyplot as plt fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False) @@ -391,81 +401,3 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs fig.colorbar(im, cax=cbar_ax, label=metric) return fig - - def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None): - if case_keys is None: - case_keys = list(self.cases.keys()) - import pylab as plt - - figs = [] 
- for count, key in enumerate(case_keys): - label = self.cases[key]["label"] - comp = self.get_result(key)["gt_comparison"] - - unit_index = np.flatnonzero(np.sum(comp.agreement_scores.values > overmerged_score, axis=0) > 1) - overmerged_ids = comp.sorting2.unit_ids[unit_index] - - n = min(len(overmerged_ids), max_units) - if n > 0: - fig, axs = plt.subplots(nrows=n, figsize=figsize) - for i, unit_id in enumerate(overmerged_ids[:n]): - gt_unit_indices = np.flatnonzero(comp.agreement_scores.loc[:, unit_id].values > overmerged_score) - gt_unit_ids = comp.sorting1.unit_ids[gt_unit_indices] - ax = axs[i] - ax.set_title(f"unit {unit_id} - GTids {gt_unit_ids}") - - analyzer = self.get_sorting_analyzer(key) - - wf_template = analyzer.get_extension("templates") - templates = wf_template.get_templates(unit_ids=gt_unit_ids) - if analyzer.sparsity is not None: - chan_mask = np.any(analyzer.sparsity.mask[gt_unit_indices, :], axis=0) - templates = templates[:, :, chan_mask] - ax.plot(templates.swapaxes(1, 2).reshape(templates.shape[0], -1).T) - ax.set_xticks([]) - - fig.suptitle(label) - figs.append(fig) - else: - print(key, "no overmerged") - - return figs - - def plot_some_over_splited(self, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None): - if case_keys is None: - case_keys = list(self.cases.keys()) - import pylab as plt - - figs = [] - for count, key in enumerate(case_keys): - label = self.cases[key]["label"] - comp = self.get_result(key)["gt_comparison"] - - gt_unit_indices = np.flatnonzero(np.sum(comp.agreement_scores.values > oversplit_score, axis=1) > 1) - oversplit_ids = comp.sorting1.unit_ids[gt_unit_indices] - - n = min(len(oversplit_ids), max_units) - if n > 0: - fig, axs = plt.subplots(nrows=n, figsize=figsize) - for i, unit_id in enumerate(oversplit_ids[:n]): - unit_indices = np.flatnonzero(comp.agreement_scores.loc[unit_id, :].values > oversplit_score) - unit_ids = comp.sorting2.unit_ids[unit_indices] - ax = axs[i] - ax.set_title(f"Gt unit {unit_id} - unit_ids: {unit_ids}") - - templates = self.get_result(key)["clustering_templates"] - - template_arrays = templates.get_dense_templates()[unit_indices, :, :] - if templates.sparsity is not None: - chan_mask = np.any(templates.sparsity.mask[gt_unit_indices, :], axis=0) - template_arrays = template_arrays[:, :, chan_mask] - - ax.plot(template_arrays.swapaxes(1, 2).reshape(template_arrays.shape[0], -1).T) - ax.set_xticks([]) - - fig.suptitle(label) - figs.append(fig) - else: - print(key, "no over splited") - - return figs diff --git a/src/spikeinterface/benchmark/benchmark_peak_localization.py b/src/spikeinterface/benchmark/benchmark_peak_localization.py index 546beff6bb..8c5afcaacd 100644 --- a/src/spikeinterface/benchmark/benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/benchmark_peak_localization.py @@ -86,7 +86,7 @@ def plot_comparison_positions(self, case_keys=None): if case_keys is None: case_keys = list(self.cases.keys()) - import pylab as plt + import matplotlib.pyplot as plt fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5)) @@ -222,7 +222,7 @@ def plot_template_errors(self, case_keys=None, show_probe=True): if case_keys is None: case_keys = list(self.cases.keys()) - import pylab as plt + import matplotlib.pyplot as plt fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5)) @@ -248,7 +248,7 @@ def plot_comparison_positions(self, case_keys=None): if case_keys is None: case_keys = list(self.cases.keys()) - import pylab as plt + import matplotlib.pyplot as plt fig, axs = plt.subplots(ncols=3, 
nrows=1, figsize=(15, 5)) @@ -416,7 +416,7 @@ def plot_comparison_positions(self, case_keys=None): # def plot_comparison_precision(benchmarks): -# import pylab as plt +# import matplotlib.pyplot as plt # fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(15, 10), squeeze=False) @@ -487,7 +487,7 @@ def plot_comparison_positions(self, case_keys=None): # norms = np.linalg.norm(benchmark.gt_positions[:, :2], axis=1) # cell_ind = np.argsort(norms)[0] -# import pylab as plt +# import matplotlib.pyplot as plt # fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(15, 10)) # plot_probe_map(benchmark.recording, ax=axs[0, 0]) diff --git a/src/spikeinterface/benchmark/benchmark_plot_tools.py b/src/spikeinterface/benchmark/benchmark_plot_tools.py index 9be89ed8d5..cc35f600b2 100644 --- a/src/spikeinterface/benchmark/benchmark_plot_tools.py +++ b/src/spikeinterface/benchmark/benchmark_plot_tools.py @@ -842,6 +842,7 @@ def plot_performances_comparison( performance_colors={"accuracy": "g", "recall": "b", "precision": "r"}, levels_to_group_by=None, ylim=(-0.1, 1.1), + axs=None, ): """ Plot performances comparison for a study. @@ -881,7 +882,8 @@ def plot_performances_comparison( [key in performance_colors for key in performance_names] ), f"performance_colors must have a color for each performance name: {performance_names}" - fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=figsize, squeeze=False) + if axs is None: + fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=figsize, squeeze=False) for i, key1 in enumerate(case_keys): for j, key2 in enumerate(case_keys): if i < j: @@ -897,7 +899,8 @@ def plot_performances_comparison( comp1 = study.get_result(sub_key1)["gt_comparison"] comp2 = study.get_result(sub_key2)["gt_comparison"] - for performance_name, color in performance_colors.items(): + for performance_name in performance_names: + color = performance_colors[performance_name] perf1 = comp1.get_performance()[performance_name] perf2 = comp2.get_performance()[performance_name] ax.scatter(perf2, perf1, marker=".", label=performance_name, color=color) @@ -923,9 +926,11 @@ def plot_performances_comparison( patches = [] from matplotlib.patches import Patch - for name, color in performance_colors.items(): - patches.append(Patch(color=color, label=name)) + for performance_name in performance_names: + color = performance_colors[performance_name] + patches.append(Patch(color=color, label=performance_name)) ax.legend(handles=patches) + fig = ax.figure fig.subplots_adjust(hspace=0.1, wspace=0.1) return fig @@ -964,7 +969,7 @@ def plot_performances_vs_depth_and_snr( fig : matplotlib.figure.Figure The resulting figure containing the plots. """ - import pylab as plt + import matplotlib.pyplot as plt if case_keys is None: case_keys = list(study.cases.keys()) @@ -1082,3 +1087,102 @@ def plot_performance_losses( despine(axs) return fig + + +def plot_some_over_merged(study, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None): + """ + Plot some waveforms of overmerged units. 
+ """ + + if case_keys is None: + case_keys = list(study.cases.keys()) + import matplotlib.pyplot as plt + + figs = [] + for count, key in enumerate(case_keys): + label = study.cases[key]["label"] + comp = study.get_result(key)["gt_comparison"] + + unit_index = np.flatnonzero(np.sum(comp.agreement_scores.values > overmerged_score, axis=0) > 1) + overmerged_ids = comp.sorting2.unit_ids[unit_index] + + n = min(len(overmerged_ids), max_units) + if n > 0: + fig, axs = plt.subplots(nrows=n, figsize=figsize, squeeze=False) + axs = axs[:, 0] + for i, unit_id in enumerate(overmerged_ids[:n]): + gt_unit_indices = np.flatnonzero(comp.agreement_scores.loc[:, unit_id].values > overmerged_score) + gt_unit_ids = comp.sorting1.unit_ids[gt_unit_indices] + ax = axs[i] + ax.set_title(f"unit {unit_id} - GTids {gt_unit_ids}") + + analyzer = study.get_sorting_analyzer(key) + + wf_template = analyzer.get_extension("templates") + templates = wf_template.get_templates(unit_ids=gt_unit_ids) + if analyzer.sparsity is not None: + chan_mask = np.any(analyzer.sparsity.mask[gt_unit_indices, :], axis=0) + templates = templates[:, :, chan_mask] + ax.plot(templates.swapaxes(1, 2).reshape(templates.shape[0], -1).T) + ax.set_xticks([]) + + fig.suptitle(label) + figs.append(fig) + else: + print(key, "no overmerged") + + return figs + + +def plot_some_over_splited(study, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None): + """ + Plot some waveforms of over-splitted units. + """ + if case_keys is None: + case_keys = list(study.cases.keys()) + import matplotlib.pyplot as plt + + print(case_keys) + figs = [] + for count, key in enumerate(case_keys): + print(key) + label = study.cases[key]["label"] + comp = study.get_result(key)["gt_comparison"] + + gt_unit_indices = np.flatnonzero(np.sum(comp.agreement_scores.values > oversplit_score, axis=1) > 1) + oversplit_ids = comp.sorting1.unit_ids[gt_unit_indices] + + n = min(len(oversplit_ids), max_units) + if n > 0: + fig, axs = plt.subplots(nrows=n, figsize=figsize, squeeze=False) + axs = axs[:, 0] + for i, unit_id in enumerate(oversplit_ids[:n]): + unit_indices = np.flatnonzero(comp.agreement_scores.loc[unit_id, :].values > oversplit_score) + unit_ids = comp.sorting2.unit_ids[unit_indices] + ax = axs[i] + ax.set_title(f"Gt unit {unit_id} - unit_ids: {unit_ids}") + + results = study.get_result(key) + if "clustering_templates" in results: + # ClusteringBenchmark has this + templates = results["clustering_templates"] + elif "sorter_analyzer" in results: + # SorterBenchmark has this + templates = results["sorter_analyzer"].get_extension("templates").get_data(outputs="Templates") + else: + raise ValueError("This benchmark do not have templates computed") + + template_arrays = templates.get_dense_templates()[unit_indices, :, :] + if templates.sparsity is not None: + chan_mask = np.any(templates.sparsity.mask[gt_unit_indices, :], axis=0) + template_arrays = template_arrays[:, :, chan_mask] + + ax.plot(template_arrays.swapaxes(1, 2).reshape(template_arrays.shape[0], -1).T) + ax.set_xticks([]) + + fig.suptitle(label) + figs.append(fig) + else: + print(key, "no over splited") + + return figs diff --git a/src/spikeinterface/benchmark/benchmark_sorter.py b/src/spikeinterface/benchmark/benchmark_sorter.py index 6f50800376..a49b0d4e7c 100644 --- a/src/spikeinterface/benchmark/benchmark_sorter.py +++ b/src/spikeinterface/benchmark/benchmark_sorter.py @@ -3,7 +3,7 @@ """ import numpy as np -from spikeinterface.core import NumpySorting +from spikeinterface.core import NumpySorting, 
 from .benchmark_base import Benchmark, BenchmarkStudy, MixinStudyUnitCount
 from spikeinterface.sorters import run_sorter
 from spikeinterface.comparison import compare_sorter_to_ground_truth
@@ -25,7 +25,7 @@ def run(self):
         sorting = NumpySorting.from_sorting(raw_sorting)
         self.result = {"sorting": sorting}
 
-    def compute_result(self, match_score=0.5, exhaustive_gt=True):
+    def compute_result(self, match_score=0.5, exhaustive_gt=True, with_analyzer=False, **job_kwargs):
         # run becnhmark result
         sorting = self.result["sorting"]
         comp = compare_sorter_to_ground_truth(
@@ -33,11 +33,19 @@
         )
         self.result["gt_comparison"] = comp
 
+        if with_analyzer:
+            # optionally compute an analyzer so templates are available for the oversplit/overmerged plots
+            analyzer = create_sorting_analyzer(sorting, self.recording, format="memory", sparse=True, **job_kwargs)
+            analyzer.compute("random_spikes")
+            analyzer.compute("templates", **job_kwargs)
+            self.result["sorter_analyzer"] = analyzer
+
     _run_key_saved = [
         ("sorting", "sorting"),
     ]
     _result_key_saved = [
         ("gt_comparison", "pickle"),
+        ("sorter_analyzer", "sorting_analyzer"),
     ]
@@ -112,3 +120,13 @@ def plot_performance_losses(self, *args, **kwargs):
         from .benchmark_plot_tools import plot_performance_losses
 
         return plot_performance_losses(self, *args, **kwargs)
+
+    def plot_some_over_merged(self, *args, **kwargs):
+        from .benchmark_plot_tools import plot_some_over_merged
+
+        return plot_some_over_merged(self, *args, **kwargs)
+
+    def plot_some_over_splited(self, *args, **kwargs):
+        from .benchmark_plot_tools import plot_some_over_splited
+
+        return plot_some_over_splited(self, *args, **kwargs)
diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py
index 1d313aaf23..be1cf18fbf 100644
--- a/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py
+++ b/src/spikeinterface/benchmark/tests/test_benchmark_clustering.py
@@ -52,7 +52,6 @@ def test_benchmark_clustering(create_cache_folder):
     print(study)
 
     # this study needs analyzer
-    # study.create_sorting_analyzer_gt(**job_kwargs)
     study.compute_metrics()
 
     study = ClusteringStudy(study_folder)
diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_matching.py b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py
index 09a263e34e..22b519ec78 100644
--- a/src/spikeinterface/benchmark/tests/test_benchmark_matching.py
+++ b/src/spikeinterface/benchmark/tests/test_benchmark_matching.py
@@ -52,7 +52,6 @@ def test_benchmark_matching(create_cache_folder):
     print(study)
 
     # this study needs analyzer
-    # study.create_sorting_analyzer_gt(**job_kwargs)
     study.compute_metrics()
 
     # run and result
diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py b/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py
index a9b64d19ea..059c416ab9 100644
--- a/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py
+++ b/src/spikeinterface/benchmark/tests/test_benchmark_motion_interpolation.py
@@ -113,7 +113,6 @@ def test_benchmark_motion_interpolation(create_cache_folder):
     study = MotionInterpolationStudy.create(study_folder, datasets, cases)
 
     # this study needs analyzer
-    study.create_sorting_analyzer_gt(**job_kwargs)
     study.compute_metrics()
 
     # run and result
diff --git a/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py b/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py
index dc4527b761..b4938a0b59 100644 --- a/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py +++ b/src/spikeinterface/benchmark/tests/test_benchmark_peak_localization.py @@ -39,7 +39,6 @@ def test_benchmark_peak_localization(create_cache_folder): print(study) # this study needs analyzer - study.create_sorting_analyzer_gt(**job_kwargs) study.compute_metrics() # run and result @@ -81,7 +80,6 @@ def test_benchmark_unit_locations(create_cache_folder): print(study) # this study needs analyzer - study.create_sorting_analyzer_gt(**job_kwargs) study.compute_metrics() # run and result diff --git a/src/spikeinterface/sortingcomponents/clustering/cleaning_tools.py b/src/spikeinterface/sortingcomponents/clustering/cleaning_tools.py index 8028761ccb..4c939695f4 100644 --- a/src/spikeinterface/sortingcomponents/clustering/cleaning_tools.py +++ b/src/spikeinterface/sortingcomponents/clustering/cleaning_tools.py @@ -642,7 +642,7 @@ def detect_mixtures(templates, method_kwargs={}, job_kwargs={}, tmp_folder=None, similar_templates[1] += [unit_ids[i]] if DEBUG: - import pylab as plt + import matplotlib.pyplot as plt fig, axes = plt.subplots(1, 2) from spikeinterface.widgets import plot_traces
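
Usage note: a minimal sketch of how the new over-merged/over-split plotting entry points fit together, assuming a ground-truth sorter study created beforehand in a study folder. The import path, folder name, and run() call are assumptions, not part of this diff; compute_results() (shown in benchmark_base.py above) forwards with_analyzer=True to SorterBenchmark.compute_result() so that templates exist for the plots.

    from spikeinterface.benchmark import SorterStudy  # assumed public import path

    # load a previously created ground-truth study (hypothetical folder name)
    study = SorterStudy("my_sorter_study_folder")

    # run the sorters, then compute results; with_analyzer=True is forwarded to
    # SorterBenchmark.compute_result and builds an in-memory SortingAnalyzer with
    # "random_spikes" and "templates", which the plots below need
    study.run()
    study.compute_results(with_analyzer=True)

    # new convenience wrappers around benchmark_plot_tools; each returns a list of figures
    figs_merged = study.plot_some_over_merged(max_units=5)
    figs_split = study.plot_some_over_splited(max_units=5)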