Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/spikeinterface/exporters/to_pynapple.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
def to_pynapple_tsgroup(
sorting_analyzer_or_sorting: SortingAnalyzer | BaseSorting,
attach_unit_metadata=True,
attach_unit_properties=True,
segment_index=None,
):
"""
Expand All @@ -21,6 +22,8 @@ def to_pynapple_tsgroup(
If True, any relevant available metadata is attached to the TsGroup. Will attach
`unit_locations`, `quality_metrics` and `template_metrics` if computed. If False,
no metadata is included.
attach_unit_properties : bool, default: False
If True, attach properties of the sorting.
segment_index : int | None, default: None
The segment index. Can be None if mono-segment sorting.

Expand Down Expand Up @@ -59,14 +62,13 @@ def to_pynapple_tsgroup(
for unit_id_int, unit_id in zip(unit_ids_ints, unit_ids)
}

metadata_list = []
metadata_list = [] # init list to collect metadata dataframes

if not unit_ids_castable:
metadata_list.append(pd.DataFrame(unit_ids, columns=["unit_id"]))

# Look for good metadata to add, if there is a sorting analyzer
if attach_unit_metadata and isinstance(sorting_analyzer_or_sorting, SortingAnalyzer):

metadata_list = []
if (unit_locations := sorting_analyzer_or_sorting.get_extension("unit_locations")) is not None:
array_of_unit_locations = unit_locations.get_data()
n_dims = np.shape(sorting_analyzer_or_sorting.get_extension("unit_locations").get_data())[1]
Expand All @@ -79,6 +81,16 @@ def to_pynapple_tsgroup(
if (template_metrics := sorting_analyzer_or_sorting.get_extension("template_metrics")) is not None:
metadata_list.append(template_metrics.get_data())

# attach unit properties from sorting
if attach_unit_properties:
property_df = pd.DataFrame(index=unit_ids)
property_keys = sorting.get_property_keys() # get property keys of sorting
if len(property_keys):
for property_key in property_keys: # loop through sorting's properties
property_data = sorting.get_property(property_key)
property_df[property_key] = list(property_data)
metadata_list.append(property_df)

if len(metadata_list) > 0:
metadata = pd.concat(metadata_list, axis=1)
metadata.index = unit_ids_ints
Expand Down