21 changes: 0 additions & 21 deletions src/spikeinterface/benchmark/benchmark_base.py
@@ -457,27 +457,6 @@ def compute_results(self, case_keys=None, verbose=False, **result_params):
benchmark.compute_result(**result_params)
benchmark.save_result(self.folder / "results" / self.key_to_str(key))

def create_sorting_analyzer_gt(self, case_keys=None, return_in_uV=True, random_params={}, **job_kwargs):
print("###### Study.create_sorting_analyzer_gt() is not used anymore!!!!!!")
# if case_keys is None:
# case_keys = self.cases.keys()

# base_folder = self.folder / "sorting_analyzer"
# base_folder.mkdir(exist_ok=True)

# dataset_keys = [self.cases[key]["dataset"] for key in case_keys]
# dataset_keys = set(dataset_keys)
# for dataset_key in dataset_keys:
# # the waveforms depend on the dataset key
# folder = base_folder / self.key_to_str(dataset_key)
# recording, gt_sorting = self.datasets[dataset_key]
# sorting_analyzer = create_sorting_analyzer(
# gt_sorting, recording, format="binary_folder", folder=folder, return_in_uV=return_in_uV
# )
# sorting_analyzer.compute("random_spikes", **random_params)
# sorting_analyzer.compute("templates", **job_kwargs)
# sorting_analyzer.compute("noise_levels")
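For reference, the removed helper can be reproduced directly with the public API. A minimal sketch based on the commented-out body above; the `recording`, `gt_sorting` and `folder` variables are assumed to exist and are not part of this diff:

from spikeinterface import create_sorting_analyzer

# build a ground-truth analyzer manually, mirroring the removed helper
sorting_analyzer = create_sorting_analyzer(gt_sorting, recording, format="binary_folder", folder=folder)
sorting_analyzer.compute("random_spikes")
sorting_analyzer.compute("templates")
sorting_analyzer.compute("noise_levels")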

def get_sorting_analyzer(self, case_key=None, dataset_key=None):
if case_key is not None:
dataset_key = self.cases[case_key]["dataset"]
94 changes: 13 additions & 81 deletions src/spikeinterface/benchmark/benchmark_clustering.py
@@ -222,11 +222,21 @@ def plot_performances_ordered(self, *args, **kwargs):

return plot_performances_ordered(self, *args, **kwargs)

def plot_some_over_merged(self, *args, **kwargs):
from .benchmark_plot_tools import plot_some_over_merged

return plot_some_over_merged(self, *args, **kwargs)

def plot_some_over_splited(self, *args, **kwargs):
from .benchmark_plot_tools import plot_some_over_splited

return plot_some_over_splited(self, *args, **kwargs)

def plot_error_metrics(self, metric="cosine", case_keys=None, figsize=(15, 5)):

if case_keys is None:
case_keys = list(self.cases.keys())
- import pylab as plt
+ import matplotlib.pyplot as plt

fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)

@@ -263,7 +273,7 @@ def plot_metrics_vs_snr(self, metric="agreement", case_keys=None, figsize=(15, 5

if case_keys is None:
case_keys = list(self.cases.keys())
- import pylab as plt
+ import matplotlib.pyplot as plt

if axes is None:
fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)
@@ -322,7 +332,7 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs

if case_keys is None:
case_keys = list(self.cases.keys())
- import pylab as plt
+ import matplotlib.pyplot as plt

fig, axes = plt.subplots(ncols=len(case_keys), nrows=1, figsize=figsize, squeeze=False)

@@ -391,81 +401,3 @@ def plot_metrics_vs_depth_and_snr(self, metric="agreement", case_keys=None, figs
fig.colorbar(im, cax=cbar_ax, label=metric)

return fig
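For context, a minimal call sketch for the method above; the `study` object, a clustering benchmark study with computed results, is assumed:

# plot the agreement metric against unit depth and SNR for every case
fig = study.plot_metrics_vs_depth_and_snr(metric="agreement")
fig.suptitle("agreement vs depth and SNR")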

def plot_some_over_merged(self, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None):
if case_keys is None:
case_keys = list(self.cases.keys())
import pylab as plt

figs = []
for count, key in enumerate(case_keys):
label = self.cases[key]["label"]
comp = self.get_result(key)["gt_comparison"]

unit_index = np.flatnonzero(np.sum(comp.agreement_scores.values > overmerged_score, axis=0) > 1)
overmerged_ids = comp.sorting2.unit_ids[unit_index]

n = min(len(overmerged_ids), max_units)
if n > 0:
fig, axs = plt.subplots(nrows=n, figsize=figsize)
for i, unit_id in enumerate(overmerged_ids[:n]):
gt_unit_indices = np.flatnonzero(comp.agreement_scores.loc[:, unit_id].values > overmerged_score)
gt_unit_ids = comp.sorting1.unit_ids[gt_unit_indices]
ax = axs[i]
ax.set_title(f"unit {unit_id} - GTids {gt_unit_ids}")

analyzer = self.get_sorting_analyzer(key)

wf_template = analyzer.get_extension("templates")
templates = wf_template.get_templates(unit_ids=gt_unit_ids)
if analyzer.sparsity is not None:
chan_mask = np.any(analyzer.sparsity.mask[gt_unit_indices, :], axis=0)
templates = templates[:, :, chan_mask]
ax.plot(templates.swapaxes(1, 2).reshape(templates.shape[0], -1).T)
ax.set_xticks([])

fig.suptitle(label)
figs.append(fig)
else:
print(key, "no overmerged")

return figs

def plot_some_over_splited(self, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None):
if case_keys is None:
case_keys = list(self.cases.keys())
import pylab as plt

figs = []
for count, key in enumerate(case_keys):
label = self.cases[key]["label"]
comp = self.get_result(key)["gt_comparison"]

gt_unit_indices = np.flatnonzero(np.sum(comp.agreement_scores.values > oversplit_score, axis=1) > 1)
oversplit_ids = comp.sorting1.unit_ids[gt_unit_indices]

n = min(len(oversplit_ids), max_units)
if n > 0:
fig, axs = plt.subplots(nrows=n, figsize=figsize)
for i, unit_id in enumerate(oversplit_ids[:n]):
unit_indices = np.flatnonzero(comp.agreement_scores.loc[unit_id, :].values > oversplit_score)
unit_ids = comp.sorting2.unit_ids[unit_indices]
ax = axs[i]
ax.set_title(f"Gt unit {unit_id} - unit_ids: {unit_ids}")

templates = self.get_result(key)["clustering_templates"]

template_arrays = templates.get_dense_templates()[unit_indices, :, :]
if templates.sparsity is not None:
chan_mask = np.any(templates.sparsity.mask[gt_unit_indices, :], axis=0)
template_arrays = template_arrays[:, :, chan_mask]

ax.plot(template_arrays.swapaxes(1, 2).reshape(template_arrays.shape[0], -1).T)
ax.set_xticks([])

fig.suptitle(label)
figs.append(fig)
else:
print(key, "no over splited")

return figs
10 changes: 5 additions & 5 deletions src/spikeinterface/benchmark/benchmark_peak_localization.py
@@ -86,7 +86,7 @@ def plot_comparison_positions(self, case_keys=None):

if case_keys is None:
case_keys = list(self.cases.keys())
- import pylab as plt
+ import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))

@@ -222,7 +222,7 @@ def plot_template_errors(self, case_keys=None, show_probe=True):

if case_keys is None:
case_keys = list(self.cases.keys())
- import pylab as plt
+ import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))

@@ -248,7 +248,7 @@ def plot_comparison_positions(self, case_keys=None):

if case_keys is None:
case_keys = list(self.cases.keys())
- import pylab as plt
+ import matplotlib.pyplot as plt

fig, axs = plt.subplots(ncols=3, nrows=1, figsize=(15, 5))

@@ -416,7 +416,7 @@ def plot_comparison_positions(self, case_keys=None):


# def plot_comparison_precision(benchmarks):
- # import pylab as plt
+ # import matplotlib.pyplot as plt

# fig, axes = plt.subplots(ncols=2, nrows=1, figsize=(15, 10), squeeze=False)

@@ -487,7 +487,7 @@ def plot_comparison_positions(self, case_keys=None):
# norms = np.linalg.norm(benchmark.gt_positions[:, :2], axis=1)
# cell_ind = np.argsort(norms)[0]

- # import pylab as plt
+ # import matplotlib.pyplot as plt

# fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(15, 10))
# plot_probe_map(benchmark.recording, ax=axs[0, 0])
114 changes: 109 additions & 5 deletions src/spikeinterface/benchmark/benchmark_plot_tools.py
@@ -842,6 +842,7 @@ def plot_performances_comparison(
performance_colors={"accuracy": "g", "recall": "b", "precision": "r"},
levels_to_group_by=None,
ylim=(-0.1, 1.1),
+ axs=None,
):
"""
Plot performances comparison for a study.
@@ -881,7 +882,8 @@
[key in performance_colors for key in performance_names]
), f"performance_colors must have a color for each performance name: {performance_names}"

- fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=figsize, squeeze=False)
+ if axs is None:
+ fig, axs = plt.subplots(ncols=num_methods - 1, nrows=num_methods - 1, figsize=figsize, squeeze=False)
for i, key1 in enumerate(case_keys):
for j, key2 in enumerate(case_keys):
if i < j:
@@ -897,7 +899,8 @@
comp1 = study.get_result(sub_key1)["gt_comparison"]
comp2 = study.get_result(sub_key2)["gt_comparison"]

- for performance_name, color in performance_colors.items():
+ for performance_name in performance_names:
+ color = performance_colors[performance_name]
perf1 = comp1.get_performance()[performance_name]
perf2 = comp2.get_performance()[performance_name]
ax.scatter(perf2, perf1, marker=".", label=performance_name, color=color)
@@ -923,9 +926,11 @@
patches = []
from matplotlib.patches import Patch

- for name, color in performance_colors.items():
- patches.append(Patch(color=color, label=name))
+ for performance_name in performance_names:
+ color = performance_colors[performance_name]
+ patches.append(Patch(color=color, label=performance_name))
ax.legend(handles=patches)
+ fig = ax.figure
fig.subplots_adjust(hspace=0.1, wspace=0.1)
return fig
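A minimal usage sketch for the new `axs` argument; the `study` object is assumed, and with three compared cases the function expects a 2 x 2 grid of axes:

import matplotlib.pyplot as plt
from spikeinterface.benchmark.benchmark_plot_tools import plot_performances_comparison

# build the grid up front so the comparison can be embedded in a larger figure
fig, axs = plt.subplots(ncols=2, nrows=2, squeeze=False)
plot_performances_comparison(study, axs=axs)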

@@ -964,7 +969,7 @@ def plot_performances_vs_depth_and_snr(
fig : matplotlib.figure.Figure
The resulting figure containing the plots.
"""
- import pylab as plt
+ import matplotlib.pyplot as plt

if case_keys is None:
case_keys = list(study.cases.keys())
@@ -1082,3 +1087,102 @@ def plot_performance_losses(
despine(axs)

return fig


def plot_some_over_merged(study, case_keys=None, overmerged_score=0.05, max_units=5, figsize=None):
"""
Plot the templates of some overmerged units (a found unit matching several ground-truth units).
"""

if case_keys is None:
case_keys = list(study.cases.keys())
import matplotlib.pyplot as plt

figs = []
for count, key in enumerate(case_keys):
label = study.cases[key]["label"]
comp = study.get_result(key)["gt_comparison"]

unit_index = np.flatnonzero(np.sum(comp.agreement_scores.values > overmerged_score, axis=0) > 1)
overmerged_ids = comp.sorting2.unit_ids[unit_index]

n = min(len(overmerged_ids), max_units)
if n > 0:
fig, axs = plt.subplots(nrows=n, figsize=figsize, squeeze=False)
axs = axs[:, 0]
for i, unit_id in enumerate(overmerged_ids[:n]):
gt_unit_indices = np.flatnonzero(comp.agreement_scores.loc[:, unit_id].values > overmerged_score)
gt_unit_ids = comp.sorting1.unit_ids[gt_unit_indices]
ax = axs[i]
ax.set_title(f"unit {unit_id} - GTids {gt_unit_ids}")

analyzer = study.get_sorting_analyzer(key)

wf_template = analyzer.get_extension("templates")
templates = wf_template.get_templates(unit_ids=gt_unit_ids)
if analyzer.sparsity is not None:
chan_mask = np.any(analyzer.sparsity.mask[gt_unit_indices, :], axis=0)
templates = templates[:, :, chan_mask]
ax.plot(templates.swapaxes(1, 2).reshape(templates.shape[0], -1).T)
ax.set_xticks([])

fig.suptitle(label)
figs.append(fig)
else:
print(key, "no overmerged")

return figs


def plot_some_over_splited(study, case_keys=None, oversplit_score=0.05, max_units=5, figsize=None):
"""
Plot the templates of some oversplit units (a ground-truth unit matching several found units).
"""
if case_keys is None:
case_keys = list(study.cases.keys())
import matplotlib.pyplot as plt

figs = []
for count, key in enumerate(case_keys):
label = study.cases[key]["label"]
comp = study.get_result(key)["gt_comparison"]

gt_unit_indices = np.flatnonzero(np.sum(comp.agreement_scores.values > oversplit_score, axis=1) > 1)
oversplit_ids = comp.sorting1.unit_ids[gt_unit_indices]

n = min(len(oversplit_ids), max_units)
if n > 0:
fig, axs = plt.subplots(nrows=n, figsize=figsize, squeeze=False)
axs = axs[:, 0]
for i, unit_id in enumerate(oversplit_ids[:n]):
unit_indices = np.flatnonzero(comp.agreement_scores.loc[unit_id, :].values > oversplit_score)
unit_ids = comp.sorting2.unit_ids[unit_indices]
ax = axs[i]
ax.set_title(f"Gt unit {unit_id} - unit_ids: {unit_ids}")

results = study.get_result(key)
if "clustering_templates" in results:
# ClusteringBenchmark has this
templates = results["clustering_templates"]
elif "sorter_analyzer" in results:
# SorterBenchmark has this
templates = results["sorter_analyzer"].get_extension("templates").get_data(outputs="Templates")
else:
raise ValueError("This benchmark does not have computed templates")

template_arrays = templates.get_dense_templates()[unit_indices, :, :]
if templates.sparsity is not None:
# rows of the sparsity mask correspond to the found (sorting2) units
chan_mask = np.any(templates.sparsity.mask[unit_indices, :], axis=0)
template_arrays = template_arrays[:, :, chan_mask]

ax.plot(template_arrays.swapaxes(1, 2).reshape(template_arrays.shape[0], -1).T)
ax.set_xticks([])

fig.suptitle(label)
figs.append(fig)
else:
print(key, "no over splited")

return figs
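A short usage sketch for the two helpers added above; an existing clustering benchmark `study` with computed results is assumed:

from spikeinterface.benchmark.benchmark_plot_tools import plot_some_over_merged, plot_some_over_splited

# one figure per case, produced only when suspicious units are found
figs_merged = plot_some_over_merged(study, overmerged_score=0.05, max_units=5)
figs_split = plot_some_over_splited(study, oversplit_score=0.05, max_units=5)

The new ClusteringStudy methods in benchmark_clustering.py simply delegate here, so `study.plot_some_over_merged()` is equivalent.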