Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/spikeinterface/widgets/unit_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
**unitwaveformdensitymapwidget_kwargs,
)
col_counter += 1
ax_waveform_density.set_xlabel(None)
ax_waveform_density.set_ylabel(None)

if sorting_analyzer.has_extension("correlograms"):
Expand Down
64 changes: 26 additions & 38 deletions src/spikeinterface/widgets/unit_waveforms_density_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ def __init__(
templates = ext_templates.get_templates(unit_ids=unit_ids)
bin_min = np.min(templates) * 1.3
bin_max = np.max(templates) * 1.3
bin_size = (bin_max - bin_min) / 100
bins = np.arange(bin_min, bin_max, bin_size)
num_bins = 100
bins = np.linspace(bin_min, bin_max, num_bins + 1)

# 2d histograms
if same_axis:
Expand Down Expand Up @@ -121,14 +121,9 @@ def __init__(
wfs = wfs_

# make histogram density
wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1)
hist2d = np.zeros((wfs_flat.shape[1], bins.size))
indexes0 = np.arange(wfs_flat.shape[1])

wf_bined = np.floor((wfs_flat - bin_min) / bin_size).astype("int32")
wf_bined = wf_bined.clip(0, bins.size - 1)
for d in wf_bined:
hist2d[indexes0, d] += 1
wfs_flat = wfs.swapaxes(1, 2).reshape(wfs.shape[0], -1) # num_spikes x (num_channels * timepoints)
hists_per_timepoint = [np.histogram(one_timepoint, bins=bins)[0] for one_timepoint in wfs_flat.T]
hist2d = np.stack(hists_per_timepoint)
Comment on lines -124 to +126
Copy link
Member

@samuelgarcia samuelgarcia Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it faster or slower ?
I am not sure to remember because I did this long time ago but in my mind this uggly trick "histogram like" was for performence. But maybe I am wrong.
lets me check it.

Copy link
Contributor Author

@ecobost ecobost Feb 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @samuelgarcia performance happens to be slightly better for sufficiently large units, >1-2K spikes (see my comment above) but feel free to double check

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok thanks.


if same_axis:
if all_hist2d is None:
Expand Down Expand Up @@ -162,60 +157,53 @@ def __init__(
bin_min=bin_min,
bin_max=bin_max,
all_hist2d=all_hist2d,
sampling_frequency=sorting_analyzer.sampling_frequency,
templates_flat=templates_flat,
template_width=wfs.shape[1],
)

BaseWidget.__init__(self, plot_data, backend=backend, **backend_kwargs)

def plot_matplotlib(self, data_plot, **backend_kwargs):
import matplotlib.pyplot as plt
from .utils_matplotlib import make_mpl_figure

dp = to_attr(data_plot)

if backend_kwargs["axes"] is not None or backend_kwargs["ax"] is not None:
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
else:
if dp.same_axis:
num_axes = 1
else:
num_axes = len(dp.unit_ids)
if backend_kwargs["axes"] is None and backend_kwargs["ax"] is None:
backend_kwargs["ncols"] = 1
backend_kwargs["num_axes"] = num_axes
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)
backend_kwargs["num_axes"] = 1 if dp.same_axis else len(dp.unit_ids)
self.figure, self.axes, self.ax = make_mpl_figure(**backend_kwargs)

freq_khz = dp.sampling_frequency / 1000 # samples / msec
if dp.same_axis:
ax = self.ax
hist2d = dp.all_hist2d
im = ax.imshow(
x_max = len(hist2d) / freq_khz # in milliseconds
self.ax.imshow(
hist2d.T,
interpolation="nearest",
origin="lower",
aspect="auto",
extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max),
cmap="hot",
extent=(0, x_max, dp.bin_min, dp.bin_max),
cmap="Grays",
)
else:
for unit_index, unit_id in enumerate(dp.unit_ids):
for ax, unit_id in zip(self.axes.flatten(), dp.unit_ids):
hist2d = dp.all_hist2d[unit_id]
ax = self.axes.flatten()[unit_index]
im = ax.imshow(
x_max = len(hist2d) / freq_khz # in milliseconds
ax.imshow(
hist2d.T,
interpolation="nearest",
origin="lower",
aspect="auto",
extent=(0, hist2d.shape[0], dp.bin_min, dp.bin_max),
cmap="hot",
extent=(0, x_max, dp.bin_min, dp.bin_max),
cmap="Grays",
)

for unit_index, unit_id in enumerate(dp.unit_ids):
if dp.same_axis:
ax = self.ax
else:
ax = self.axes.flatten()[unit_index]
ax = self.ax if dp.same_axis else self.axes.flatten()[unit_index]
color = dp.unit_colors[unit_id]
ax.plot(dp.templates_flat[unit_id], color=color, lw=1)
x = np.arange(len(dp.templates_flat[unit_id])) / freq_khz
ax.plot(x, dp.templates_flat[unit_id], color=color, lw=1)

# final cosmetics
for unit_index, unit_id in enumerate(dp.unit_ids):
Expand All @@ -228,11 +216,11 @@ def plot_matplotlib(self, data_plot, **backend_kwargs):
chan_inds = dp.channel_inds[unit_id]
for i, chan_ind in enumerate(chan_inds):
if i != 0:
ax.axvline(i * dp.template_width, color="w", lw=3)
ax.axvline(i * dp.template_width / freq_khz, color="w", lw=3)
channel_id = dp.channel_ids[chan_ind]
x = i * dp.template_width + dp.template_width // 2
x = (i + 0.5) * dp.template_width / freq_khz
y = (dp.bin_max + dp.bin_min) / 2.0
ax.text(x, y, f"chan_id {channel_id}", color="w", ha="center", va="center")
ax.text(x, y, f"{channel_id}", color="k", ha="center", va="center")

ax.set_xticks([])
ax.set_xlabel("Time [ms]")
ax.set_ylabel(f"unit_id {unit_id}")