Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -893,6 +893,12 @@ def _compute_and_cache_spike_vector(self) -> None:
if len(sample_indices) > 0:
sample_indices = np.concatenate(sample_indices, dtype="int64")
unit_indices = np.concatenate(unit_indices, dtype="int64")
# here we only do a sort by indices a lexsort on sample_indices
# the stable=True is equivalent to np.lexsort((unit_indices, sample_indices, ))
# because because we construct by looping on unit_ids
order = np.argsort(sample_indices, stable=True)
sample_indices = sample_indices[order]
unit_indices = unit_indices[order]
n = sample_indices.size
segment_slices[segment_index, 0] = seg_pos
segment_slices[segment_index, 1] = seg_pos + n
Expand All @@ -908,7 +914,8 @@ def _compute_and_cache_spike_vector(self) -> None:
spikes.append(spikes_in_seg)

spikes = np.concatenate(spikes)
spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
# the spikes are not lexsorted here because the previous loop ensure that the spike vector is constructucted alway the same way.
# spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]

self._cached_spike_vector = spikes
self._cached_spike_vector_segment_slices = segment_slices
Expand Down
16 changes: 11 additions & 5 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ def generate_sorting(
spikes_in_seg["sample_index"] = samples
spikes_in_seg["unit_index"] = labels
spikes_in_seg["segment_index"] = segment_index
spikes.append(spikes_in_seg)

if add_spikes_on_borders:
spikes_on_borders = np.zeros(2 * num_spikes_per_border, dtype=minimum_spike_dtype)
Expand All @@ -182,10 +181,15 @@ def generate_sorting(
spikes_on_borders["sample_index"][num_spikes_per_border:] = rng.integers(
num_samples - border_size_samples, num_samples, num_spikes_per_border
)
spikes.append(spikes_on_borders)
spikes_in_seg = np.concatenate([spikes_in_seg, spikes_on_borders])
order = np.argsort(spikes_in_seg["sample_index"], stable=True)
spikes_in_seg = spikes_in_seg[order]

spikes.append(spikes_in_seg)

spikes = np.concatenate(spikes)
spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
# the spikes do not need a full lexsort because synthesize_poisson_spike_vector() garanty spikes to sorted by frame already
# spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]

sorting = NumpySorting(spikes, sampling_frequency, unit_ids)

Expand Down Expand Up @@ -799,7 +803,9 @@ def synthesize_poisson_spike_vector(

# Sort globaly
spike_frames = spike_frames[:num_correct_frames]
sort_indices = np.argsort(spike_frames, kind="stable") # I profiled the different kinds, this is the fastest.
# the stable is important because this guarantees to be equivalent to
# np.lexsort((unit_indices, spike_frames, ))
sort_indices = np.argsort(spike_frames, stable=True)

unit_indices = unit_indices[sort_indices]
spike_frames = spike_frames[sort_indices]
Expand Down Expand Up @@ -888,7 +894,7 @@ def synthesize_random_firings(
times = np.concatenate(times)
labels = np.concatenate(labels)

sort_inds = np.argsort(times)
sort_inds = np.argsort(times, stable=True)
times = times[sort_inds]
labels = labels[sort_inds]

Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/npzsortingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def write_sorting(sorting, save_path):
if len(spike_indexes) > 0:
spike_indexes = np.concatenate(spike_indexes)
spike_labels = np.concatenate(spike_labels)
order = np.argsort(spike_indexes)
order = np.argsort(spike_indexes, stable=True)
spike_indexes = spike_indexes[order]
spike_labels = spike_labels[order]
else:
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/numpyextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def from_samples_and_labels(samples_list, labels_list, sampling_frequency, unit_
spikes_in_seg["sample_index"] = times
spikes_in_seg["unit_index"] = unit_index
spikes_in_seg["segment_index"] = i
order = np.argsort(times)
order = np.argsort(times, stable=True)
spikes_in_seg = spikes_in_seg[order]
spikes.append(spikes_in_seg)
spikes = np.concatenate(spikes)
Expand Down Expand Up @@ -395,7 +395,7 @@ def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting":
sample_indices = np.concatenate(sample_indices)
unit_indices = np.concatenate(unit_indices)

order = np.argsort(sample_indices)
order = np.argsort(sample_indices, stable=True)
sample_indices = sample_indices[order]
unit_indices = unit_indices[order]

Expand Down
11 changes: 10 additions & 1 deletion src/spikeinterface/core/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,23 @@ def check_recordings_equal(


def check_sortings_equal(
SX1: BaseSorting, SX2: BaseSorting, check_annotations: bool = False, check_properties: bool = False
SX1: BaseSorting,
SX2: BaseSorting,
check_annotations: bool = False,
check_properties: bool = False,
check_exact_lexsort: bool = True,
) -> None:
assert SX1.get_num_segments() == SX2.get_num_segments()

max_spike_index = SX1.to_spike_vector()["sample_index"].max()

s1 = SX1.to_spike_vector()
s2 = SX2.to_spike_vector()
if not check_exact_lexsort:
# 2 sorting can be equal even if the internal lexsort is not the same.
# spiketrains still wiwll be the same per units
s1 = s1[np.lexsort((s1["unit_index"], s1["sample_index"], s1["segment_index"]))]
s2 = s2[np.lexsort((s2["unit_index"], s2["sample_index"], s2["segment_index"]))]
assert_array_equal(s1, s2)

for start_frame, end_frame in [
Expand Down
13 changes: 10 additions & 3 deletions src/spikeinterface/core/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,15 @@
"""

from typing import Sequence

import numpy as np

from spikeinterface.core import (
concatenate_recordings,
generate_ground_truth_recording,
generate_recording,
)
from spikeinterface.core.base import BaseExtractor
from spikeinterface.core import generate_recording, generate_ground_truth_recording, concatenate_recordings


class DummyDictExtractor(BaseExtractor):
Expand All @@ -23,7 +29,8 @@ def make_nested_extractors(extractor):
[extractor_with_parent_list, extractor_with_parent_list]
)
extractor_with_parent_dict = DummyDictExtractor(
main_ids=extractor._main_ids, base_dicts=dict(a=extractor, b=extractor, c=extractor)
main_ids=extractor._main_ids,
base_dicts=dict(a=extractor, b=extractor, c=extractor),
)
return (
extractor_wih_parent,
Expand Down Expand Up @@ -143,4 +150,4 @@ def test_setting_properties_with_custom_missing_value():

if __name__ == "__main__":
test_check_if_memory_serializable()
test_check_if_serializable()
# test_check_if_serializable()
4 changes: 2 additions & 2 deletions src/spikeinterface/core/tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,8 +655,8 @@ def test_synthesize_random_firings_length():
# test_generate_recording()
# test_generate_single_fake_waveform()
# test_transformsorting()
test_generate_unit_locations()
# test_generate_unit_locations()
# test_generate_templates()
# test_inject_templates()
# test_generate_ground_truth_recording()
# test_generate_sorting_with_spikes_on_borders()
test_generate_sorting_with_spikes_on_borders()
10 changes: 8 additions & 2 deletions src/spikeinterface/core/tests/test_sorting_folder.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_NpzFolderSorting(create_cache_folder):

sorting_loaded = NpzFolderSorting(folder)
check_sortings_equal(sorting_loaded, sorting)

assert np.array_equal(sorting_loaded.unit_ids, sorting.unit_ids)
assert np.array_equal(
sorting_loaded.to_spike_vector(),
Expand All @@ -47,5 +48,10 @@ def test_NpzFolderSorting(create_cache_folder):


if __name__ == "__main__":
test_NumpyFolderSorting()
test_NpzFolderSorting()
import tempfile
from pathlib import Path

cache_folder = Path(tempfile.mkdtemp())

test_NumpyFolderSorting(cache_folder)
test_NpzFolderSorting(cache_folder)
30 changes: 17 additions & 13 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,17 +210,20 @@ def test_create_by_dict():
Interally, this aggregates the dicts of recordings and sortings. This test checks that the
unit structure is maintained from the dicts to the analyzer. Then checks that the function
fails if the dict keys are different for the recordings and the sortings.

Note, in this tests sparse is False because units are randomlly assign to differents of the
recording and they can have no channels
"""

rec, sort = generate_ground_truth_recording(num_channels=6)
rec, sort = generate_ground_truth_recording(num_channels=6, seed=2205)

rec.set_property(key="group", values=[1, 2, 1, 1, 2, 2])
sort.set_property(key="group", values=[2, 2, 2, 1, 2, 2, 2, 1, 2, 1])

unit_ids = sort.unit_ids
split_sort = sort.split_by("group")
split_rec = rec.split_by("group")
analyzer = create_sorting_analyzer(split_sort, split_rec)
analyzer = create_sorting_analyzer(split_sort, split_rec, sparse=False)
analyzer_unit_ids = analyzer.unit_ids

assert set(analyzer.unit_ids) == set(sort.unit_ids)
Expand All @@ -236,15 +239,15 @@ def test_create_by_dict():
}

with pytest.raises(ValueError):
analyzer = create_sorting_analyzer(split_sort_bad_keys, rec.split_by("group"))
analyzer = create_sorting_analyzer(split_sort_bad_keys, rec.split_by("group"), sparse=False)

# make a dict of sortings, in a different order than the recording. This should
# still work
split_sort_different_order = {
2: sort.select_units(unit_ids=unit_ids[sort.get_property("group") == 2]),
1: sort.select_units(unit_ids=unit_ids[sort.get_property("group") == 1]),
}
combined_analyzer = create_sorting_analyzer(split_sort_different_order, rec.split_by("group"))
combined_analyzer = create_sorting_analyzer(split_sort_different_order, rec.split_by("group"), sparse=False)
assert np.all(sort.get_unit_spike_train(unit_id="5") == combined_analyzer.sorting.get_unit_spike_train(unit_id="5"))


Expand Down Expand Up @@ -715,12 +718,13 @@ def test_runtime_dependencies(dataset):


if __name__ == "__main__":
tmp_path = Path("test_SortingAnalyzer")
dataset = get_dataset()
test_SortingAnalyzer_memory(tmp_path, dataset)
test_SortingAnalyzer_binary_folder(tmp_path, dataset)
test_SortingAnalyzer_zarr(tmp_path, dataset)
test_SortingAnalyzer_tmp_recording(dataset)
test_extension()
test_extension_params()
test_runtime_dependencies()
# tmp_path = Path("test_SortingAnalyzer")
# dataset = get_dataset()
# test_SortingAnalyzer_memory(tmp_path, dataset)
# test_SortingAnalyzer_binary_folder(tmp_path, dataset)
# test_SortingAnalyzer_zarr(tmp_path, dataset)
# test_SortingAnalyzer_tmp_recording(dataset)
# test_extension()
# test_extension_params()
# test_runtime_dependencies()
test_create_by_dict()
5 changes: 4 additions & 1 deletion src/spikeinterface/core/zarrextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,10 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None,
spikes["unit_index"] = spikes_group["unit_index"][:]
for i, (start, end) in enumerate(segment_slices_list):
spikes["segment_index"][start:end] = i
spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]
# we do not need to lexsort at init (very high cost) because there already sorted by frame before to be saved.
# In version 0.104.X this was fully lexsorted, but we don't need it anymore because it's only important in the context of SpikeVectorBased extensions in the SortingAnalyzer, which stores its own copy of the Sorting object. This makes the extension data and the spike vector always matching their order.
# spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))]

self._cached_spike_vector = spikes

for segment_index in range(num_segments):
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/extractors/mdaextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def write_sorting(sorting, save_path, write_primary_channels=False):
all_times = _concatenate(times_list)
all_labels = _concatenate(labels_list)
all_primary_channels = _concatenate(primary_channels_list)
sort_inds = np.argsort(all_times)
sort_inds = np.argsort(all_times, stable=True)
all_times = all_times[sort_inds]
all_labels = all_labels[sort_inds]
all_primary_channels = all_primary_channels[sort_inds]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def test_alf_sorting_extractor():
spike_times.append(st)
spike_clusters.append(st * 0 + i)
spike_times = np.concatenate(spike_times)
ordre = np.argsort(spike_times)
spike_times = spike_times[ordre]
spike_clusters = np.concatenate(spike_clusters)[ordre]
order = np.argsort(spike_times, stable=True)
spike_times = spike_times[order]
spike_clusters = np.concatenate(spike_clusters)[order]

with tempfile.TemporaryDirectory() as td:
folder_path = Path(td)
Expand Down
5 changes: 4 additions & 1 deletion src/spikeinterface/extractors/tests/test_mdaextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,7 @@ def test_mda_extractors(create_cache_folder):


if __name__ == "__main__":
test_mda_extractors()
import tempfile

cache_folder = Path(tempfile.mkdtemp())
test_mda_extractors(cache_folder)
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def _sorting_violation():
spike_times = np.concatenate(trains)
spike_labels = np.concatenate(labels)

order = np.argsort(spike_times)
order = np.argsort(spike_times, stable=True)
max_num_samples = np.floor(max_time * sampling_frequency) - 1
indexes = np.arange(0, max_time + 1, 1 / sampling_frequency)
spike_times = np.searchsorted(indexes, spike_times[order], side="left")
Expand Down
Loading