diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index fc8373cdfb..d4c124f1bc 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -893,6 +893,12 @@ def _compute_and_cache_spike_vector(self) -> None: if len(sample_indices) > 0: sample_indices = np.concatenate(sample_indices, dtype="int64") unit_indices = np.concatenate(unit_indices, dtype="int64") + # here we only do a sort by indices a lexsort on sample_indices + # the stable=True is equivalent to np.lexsort((unit_indices, sample_indices, )) + # because because we construct by looping on unit_ids + order = np.argsort(sample_indices, stable=True) + sample_indices = sample_indices[order] + unit_indices = unit_indices[order] n = sample_indices.size segment_slices[segment_index, 0] = seg_pos segment_slices[segment_index, 1] = seg_pos + n @@ -908,7 +914,8 @@ def _compute_and_cache_spike_vector(self) -> None: spikes.append(spikes_in_seg) spikes = np.concatenate(spikes) - spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + # the spikes are not lexsorted here because the previous loop ensure that the spike vector is constructucted alway the same way. + # spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] self._cached_spike_vector = spikes self._cached_spike_vector_segment_slices = segment_slices diff --git a/src/spikeinterface/core/generate.py b/src/spikeinterface/core/generate.py index 4fa68ebec0..ef023d04b9 100644 --- a/src/spikeinterface/core/generate.py +++ b/src/spikeinterface/core/generate.py @@ -168,7 +168,6 @@ def generate_sorting( spikes_in_seg["sample_index"] = samples spikes_in_seg["unit_index"] = labels spikes_in_seg["segment_index"] = segment_index - spikes.append(spikes_in_seg) if add_spikes_on_borders: spikes_on_borders = np.zeros(2 * num_spikes_per_border, dtype=minimum_spike_dtype) @@ -182,10 +181,15 @@ def generate_sorting( spikes_on_borders["sample_index"][num_spikes_per_border:] = rng.integers( num_samples - border_size_samples, num_samples, num_spikes_per_border ) - spikes.append(spikes_on_borders) + spikes_in_seg = np.concatenate([spikes_in_seg, spikes_on_borders]) + order = np.argsort(spikes_in_seg["sample_index"], stable=True) + spikes_in_seg = spikes_in_seg[order] + + spikes.append(spikes_in_seg) spikes = np.concatenate(spikes) - spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + # the spikes do not need a full lexsort because synthesize_poisson_spike_vector() garanty spikes to sorted by frame already + # spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] sorting = NumpySorting(spikes, sampling_frequency, unit_ids) @@ -799,7 +803,9 @@ def synthesize_poisson_spike_vector( # Sort globaly spike_frames = spike_frames[:num_correct_frames] - sort_indices = np.argsort(spike_frames, kind="stable") # I profiled the different kinds, this is the fastest. + # the stable is important because this guarantees to be equivalent to + # np.lexsort((unit_indices, spike_frames, )) + sort_indices = np.argsort(spike_frames, stable=True) unit_indices = unit_indices[sort_indices] spike_frames = spike_frames[sort_indices] @@ -888,7 +894,7 @@ def synthesize_random_firings( times = np.concatenate(times) labels = np.concatenate(labels) - sort_inds = np.argsort(times) + sort_inds = np.argsort(times, stable=True) times = times[sort_inds] labels = labels[sort_inds] diff --git a/src/spikeinterface/core/npzsortingextractor.py b/src/spikeinterface/core/npzsortingextractor.py index 3d65d32744..af608d3fb7 100644 --- a/src/spikeinterface/core/npzsortingextractor.py +++ b/src/spikeinterface/core/npzsortingextractor.py @@ -53,7 +53,7 @@ def write_sorting(sorting, save_path): if len(spike_indexes) > 0: spike_indexes = np.concatenate(spike_indexes) spike_labels = np.concatenate(spike_labels) - order = np.argsort(spike_indexes) + order = np.argsort(spike_indexes, stable=True) spike_indexes = spike_indexes[order] spike_labels = spike_labels[order] else: diff --git a/src/spikeinterface/core/numpyextractors.py b/src/spikeinterface/core/numpyextractors.py index 31a3a8831d..b332822259 100644 --- a/src/spikeinterface/core/numpyextractors.py +++ b/src/spikeinterface/core/numpyextractors.py @@ -329,7 +329,7 @@ def from_samples_and_labels(samples_list, labels_list, sampling_frequency, unit_ spikes_in_seg["sample_index"] = times spikes_in_seg["unit_index"] = unit_index spikes_in_seg["segment_index"] = i - order = np.argsort(times) + order = np.argsort(times, stable=True) spikes_in_seg = spikes_in_seg[order] spikes.append(spikes_in_seg) spikes = np.concatenate(spikes) @@ -395,7 +395,7 @@ def from_unit_dict(units_dict_list, sampling_frequency) -> "NumpySorting": sample_indices = np.concatenate(sample_indices) unit_indices = np.concatenate(unit_indices) - order = np.argsort(sample_indices) + order = np.argsort(sample_indices, stable=True) sample_indices = sample_indices[order] unit_indices = unit_indices[order] diff --git a/src/spikeinterface/core/testing.py b/src/spikeinterface/core/testing.py index 2b5a7c5157..67516cbf68 100644 --- a/src/spikeinterface/core/testing.py +++ b/src/spikeinterface/core/testing.py @@ -104,7 +104,11 @@ def check_recordings_equal( def check_sortings_equal( - SX1: BaseSorting, SX2: BaseSorting, check_annotations: bool = False, check_properties: bool = False + SX1: BaseSorting, + SX2: BaseSorting, + check_annotations: bool = False, + check_properties: bool = False, + check_exact_lexsort: bool = True, ) -> None: assert SX1.get_num_segments() == SX2.get_num_segments() @@ -112,6 +116,11 @@ def check_sortings_equal( s1 = SX1.to_spike_vector() s2 = SX2.to_spike_vector() + if not check_exact_lexsort: + # 2 sorting can be equal even if the internal lexsort is not the same. + # spiketrains still wiwll be the same per units + s1 = s1[np.lexsort((s1["unit_index"], s1["sample_index"], s1["segment_index"]))] + s2 = s2[np.lexsort((s2["unit_index"], s2["sample_index"], s2["segment_index"]))] assert_array_equal(s1, s2) for start_frame, end_frame in [ diff --git a/src/spikeinterface/core/tests/test_base.py b/src/spikeinterface/core/tests/test_base.py index e3c3149ab0..e26612c988 100644 --- a/src/spikeinterface/core/tests/test_base.py +++ b/src/spikeinterface/core/tests/test_base.py @@ -4,9 +4,15 @@ """ from typing import Sequence + import numpy as np + +from spikeinterface.core import ( + concatenate_recordings, + generate_ground_truth_recording, + generate_recording, +) from spikeinterface.core.base import BaseExtractor -from spikeinterface.core import generate_recording, generate_ground_truth_recording, concatenate_recordings class DummyDictExtractor(BaseExtractor): @@ -23,7 +29,8 @@ def make_nested_extractors(extractor): [extractor_with_parent_list, extractor_with_parent_list] ) extractor_with_parent_dict = DummyDictExtractor( - main_ids=extractor._main_ids, base_dicts=dict(a=extractor, b=extractor, c=extractor) + main_ids=extractor._main_ids, + base_dicts=dict(a=extractor, b=extractor, c=extractor), ) return ( extractor_wih_parent, @@ -143,4 +150,4 @@ def test_setting_properties_with_custom_missing_value(): if __name__ == "__main__": test_check_if_memory_serializable() - test_check_if_serializable() + # test_check_if_serializable() diff --git a/src/spikeinterface/core/tests/test_generate.py b/src/spikeinterface/core/tests/test_generate.py index c5a0b83f87..c5a29c62d3 100644 --- a/src/spikeinterface/core/tests/test_generate.py +++ b/src/spikeinterface/core/tests/test_generate.py @@ -655,8 +655,8 @@ def test_synthesize_random_firings_length(): # test_generate_recording() # test_generate_single_fake_waveform() # test_transformsorting() - test_generate_unit_locations() + # test_generate_unit_locations() # test_generate_templates() # test_inject_templates() # test_generate_ground_truth_recording() - # test_generate_sorting_with_spikes_on_borders() + test_generate_sorting_with_spikes_on_borders() diff --git a/src/spikeinterface/core/tests/test_sorting_folder.py b/src/spikeinterface/core/tests/test_sorting_folder.py index a285ae8d29..8d78a89ebd 100644 --- a/src/spikeinterface/core/tests/test_sorting_folder.py +++ b/src/spikeinterface/core/tests/test_sorting_folder.py @@ -39,6 +39,7 @@ def test_NpzFolderSorting(create_cache_folder): sorting_loaded = NpzFolderSorting(folder) check_sortings_equal(sorting_loaded, sorting) + assert np.array_equal(sorting_loaded.unit_ids, sorting.unit_ids) assert np.array_equal( sorting_loaded.to_spike_vector(), @@ -47,5 +48,10 @@ def test_NpzFolderSorting(create_cache_folder): if __name__ == "__main__": - test_NumpyFolderSorting() - test_NpzFolderSorting() + import tempfile + from pathlib import Path + + cache_folder = Path(tempfile.mkdtemp()) + + test_NumpyFolderSorting(cache_folder) + test_NpzFolderSorting(cache_folder) diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..14e62384fa 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -210,9 +210,12 @@ def test_create_by_dict(): Interally, this aggregates the dicts of recordings and sortings. This test checks that the unit structure is maintained from the dicts to the analyzer. Then checks that the function fails if the dict keys are different for the recordings and the sortings. + + Note, in this tests sparse is False because units are randomlly assign to differents of the + recording and they can have no channels """ - rec, sort = generate_ground_truth_recording(num_channels=6) + rec, sort = generate_ground_truth_recording(num_channels=6, seed=2205) rec.set_property(key="group", values=[1, 2, 1, 1, 2, 2]) sort.set_property(key="group", values=[2, 2, 2, 1, 2, 2, 2, 1, 2, 1]) @@ -220,7 +223,7 @@ def test_create_by_dict(): unit_ids = sort.unit_ids split_sort = sort.split_by("group") split_rec = rec.split_by("group") - analyzer = create_sorting_analyzer(split_sort, split_rec) + analyzer = create_sorting_analyzer(split_sort, split_rec, sparse=False) analyzer_unit_ids = analyzer.unit_ids assert set(analyzer.unit_ids) == set(sort.unit_ids) @@ -236,7 +239,7 @@ def test_create_by_dict(): } with pytest.raises(ValueError): - analyzer = create_sorting_analyzer(split_sort_bad_keys, rec.split_by("group")) + analyzer = create_sorting_analyzer(split_sort_bad_keys, rec.split_by("group"), sparse=False) # make a dict of sortings, in a different order than the recording. This should # still work @@ -244,7 +247,7 @@ def test_create_by_dict(): 2: sort.select_units(unit_ids=unit_ids[sort.get_property("group") == 2]), 1: sort.select_units(unit_ids=unit_ids[sort.get_property("group") == 1]), } - combined_analyzer = create_sorting_analyzer(split_sort_different_order, rec.split_by("group")) + combined_analyzer = create_sorting_analyzer(split_sort_different_order, rec.split_by("group"), sparse=False) assert np.all(sort.get_unit_spike_train(unit_id="5") == combined_analyzer.sorting.get_unit_spike_train(unit_id="5")) @@ -715,12 +718,13 @@ def test_runtime_dependencies(dataset): if __name__ == "__main__": - tmp_path = Path("test_SortingAnalyzer") - dataset = get_dataset() - test_SortingAnalyzer_memory(tmp_path, dataset) - test_SortingAnalyzer_binary_folder(tmp_path, dataset) - test_SortingAnalyzer_zarr(tmp_path, dataset) - test_SortingAnalyzer_tmp_recording(dataset) - test_extension() - test_extension_params() - test_runtime_dependencies() + # tmp_path = Path("test_SortingAnalyzer") + # dataset = get_dataset() + # test_SortingAnalyzer_memory(tmp_path, dataset) + # test_SortingAnalyzer_binary_folder(tmp_path, dataset) + # test_SortingAnalyzer_zarr(tmp_path, dataset) + # test_SortingAnalyzer_tmp_recording(dataset) + # test_extension() + # test_extension_params() + # test_runtime_dependencies() + test_create_by_dict() diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index bbc797c693..885cba69f7 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -294,7 +294,10 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, spikes["unit_index"] = spikes_group["unit_index"][:] for i, (start, end) in enumerate(segment_slices_list): spikes["segment_index"][start:end] = i - spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + # we do not need to lexsort at init (very high cost) because there already sorted by frame before to be saved. + # In version 0.104.X this was fully lexsorted, but we don't need it anymore because it's only important in the context of SpikeVectorBased extensions in the SortingAnalyzer, which stores its own copy of the Sorting object. This makes the extension data and the spike vector always matching their order. + # spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + self._cached_spike_vector = spikes for segment_index in range(num_segments): diff --git a/src/spikeinterface/extractors/mdaextractors.py b/src/spikeinterface/extractors/mdaextractors.py index 7a7bdd45a6..12a7474df7 100644 --- a/src/spikeinterface/extractors/mdaextractors.py +++ b/src/spikeinterface/extractors/mdaextractors.py @@ -252,7 +252,7 @@ def write_sorting(sorting, save_path, write_primary_channels=False): all_times = _concatenate(times_list) all_labels = _concatenate(labels_list) all_primary_channels = _concatenate(primary_channels_list) - sort_inds = np.argsort(all_times) + sort_inds = np.argsort(all_times, stable=True) all_times = all_times[sort_inds] all_labels = all_labels[sort_inds] all_primary_channels = all_primary_channels[sort_inds] diff --git a/src/spikeinterface/extractors/tests/test_alfsortingextractor.py b/src/spikeinterface/extractors/tests/test_alfsortingextractor.py index a873965ddd..6ed5220f1b 100644 --- a/src/spikeinterface/extractors/tests/test_alfsortingextractor.py +++ b/src/spikeinterface/extractors/tests/test_alfsortingextractor.py @@ -19,9 +19,9 @@ def test_alf_sorting_extractor(): spike_times.append(st) spike_clusters.append(st * 0 + i) spike_times = np.concatenate(spike_times) - ordre = np.argsort(spike_times) - spike_times = spike_times[ordre] - spike_clusters = np.concatenate(spike_clusters)[ordre] + order = np.argsort(spike_times, stable=True) + spike_times = spike_times[order] + spike_clusters = np.concatenate(spike_clusters)[order] with tempfile.TemporaryDirectory() as td: folder_path = Path(td) diff --git a/src/spikeinterface/extractors/tests/test_mdaextractors.py b/src/spikeinterface/extractors/tests/test_mdaextractors.py index 1ed930613f..2dad997910 100644 --- a/src/spikeinterface/extractors/tests/test_mdaextractors.py +++ b/src/spikeinterface/extractors/tests/test_mdaextractors.py @@ -42,4 +42,7 @@ def test_mda_extractors(create_cache_folder): if __name__ == "__main__": - test_mda_extractors() + import tempfile + + cache_folder = Path(tempfile.mkdtemp()) + test_mda_extractors(cache_folder) diff --git a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py index e267b176ce..6350097d08 100644 --- a/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py +++ b/src/spikeinterface/metrics/quality/tests/test_metrics_functions.py @@ -60,7 +60,7 @@ def _sorting_violation(): spike_times = np.concatenate(trains) spike_labels = np.concatenate(labels) - order = np.argsort(spike_times) + order = np.argsort(spike_times, stable=True) max_num_samples = np.floor(max_time * sampling_frequency) - 1 indexes = np.arange(0, max_time + 1, 1 / sampling_frequency) spike_times = np.searchsorted(indexes, spike_times[order], side="left")