From 3dc57290dbde0aeaa5048f2301ee75015a93fe26 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Dec 2025 15:43:44 +0100 Subject: [PATCH 01/12] Test IBL extractors tests failing for PI update --- src/spikeinterface/extractors/tests/test_iblextractors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 972a8e7bb0..56d01e38cf 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -76,8 +76,8 @@ def test_offsets(self): def test_probe_representation(self): probe = self.recording.get_probe() - expected_probe_representation = "Probe - 384ch - 1shanks" - assert repr(probe) == expected_probe_representation + expected_probe_representation = "Probe - 384ch" + assert expected_probe_representation in repr(probe) def test_property_keys(self): expected_property_keys = [ From 359b68bd234fa06b6e07b8037b72ae64a7801480 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 9 Apr 2026 17:38:35 +0200 Subject: [PATCH 02/12] Implement get_unit_spike_trains function --- src/spikeinterface/core/basesorting.py | 208 +++++++++++++++++- .../core/unitsselectionsorting.py | 23 +- 2 files changed, 220 insertions(+), 11 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index cb68f3d455..4e0e1fed52 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -180,6 +180,7 @@ def get_unit_spike_train( segment_index=segment_index, start_time=start_time, end_time=end_time, + use_cache=use_cache, ) segment_index = self._check_segment_index(segment_index) @@ -212,6 +213,7 @@ def get_unit_spike_train_in_seconds( segment_index: int | None = None, start_time: float | None = None, end_time: float | None = None, + use_cache: bool = True, ) -> np.ndarray: """ Get spike train for a unit in seconds. @@ -236,6 +238,10 @@ def get_unit_spike_train_in_seconds( The start time in seconds for spike train extraction end_time : float or None, default: None The end time in seconds for spike train extraction + return_times : bool, default: False + If True, returns spike times in seconds instead of frames + use_cache : bool, default: True + If True, then precompute (or use) the to_reordered_spike_vector using Returns ------- @@ -258,7 +264,7 @@ def get_unit_spike_train_in_seconds( start_frame=start_frame, end_frame=end_frame, return_times=False, - use_cache=True, + use_cache=use_cache, ) spike_times = self.sample_index_to_time(spike_frames, segment_index=segment_index) @@ -288,13 +294,170 @@ def get_unit_spike_train_in_seconds( start_frame=start_frame, end_frame=end_frame, return_times=False, - use_cache=True, + use_cache=use_cache, ) t_start = segment._t_start if segment._t_start is not None else 0 spike_times = spike_frames / self.get_sampling_frequency() return t_start + spike_times + def get_unit_spike_trains( + self, + unit_ids: np.ndarray | list, + segment_index: int | None = None, + start_frame: int | None = None, + end_frame: int | None = None, + return_times: bool = False, + use_cache: bool = True, + ) -> dict[int | str, np.ndarray]: + """Return spike trains for multiple units. + + Parameters + ---------- + unit_ids : np.ndarray | list + Unit ids to retrieve spike trains for + segment_index : int or None, default: None + The segment index to retrieve spike train from. + For multi-segment objects, it is required + start_frame : int or None, default: None + The start frame for spike train extraction + end_frame : int or None, default: None + The end frame for spike train extraction + return_times : bool, default: False + If True, returns spike times in seconds instead of frames + use_cache : bool, default: True + If True, then precompute (or use) the to_reordered_spike_vector using + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit ids and values are spike trains (arrays of spike times or frames) + """ + if return_times: + start_time = ( + self.sample_index_to_time(start_frame, segment_index=segment_index) if start_frame is not None else None + ) + end_time = ( + self.sample_index_to_time(end_frame, segment_index=segment_index) if end_frame is not None else None + ) + + return self.get_unit_spike_trains_in_seconds( + unit_ids=unit_ids, + segment_index=segment_index, + start_time=start_time, + end_time=end_time, + use_cache=use_cache, + ) + + segment_index = self._check_segment_index(segment_index) + if use_cache: + # TODO: speed things up + ordered_spike_vector, slices = self.to_reordered_spike_vector( + lexsort=("sample_index", "segment_index", "unit_index"), + return_order=False, + return_slices=True, + ) + unit_indices = self.ids_to_indices(unit_ids) + spike_trains = {} + for unit_index, unit_id in zip(unit_indices, unit_ids): + sl0, sl1 = slices[unit_index, segment_index, :] + spikes = ordered_spike_vector[sl0:sl1] + spike_frames = spikes["sample_index"] + if start_frame is not None: + start = np.searchsorted(spike_frames, start_frame) + spike_frames = spike_frames[start:] + if end_frame is not None: + end = np.searchsorted(spike_frames, end_frame) + spike_frames = spike_frames[:end] + spike_trains[unit_id] = spike_frames + else: + spike_trains = segment.get_unit_spike_trains( + unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame, return_times=return_times + ) + + def get_unit_spike_trains_in_seconds( + self, + unit_ids: np.ndarray | list, + segment_index: int | None = None, + start_time: float | None = None, + end_time: float | None = None, + return_times: bool = False, + use_cache: bool = True, + ) -> dict[int | str, np.ndarray]: + """Return spike trains for multiple units in seconds + + Parameters + ---------- + unit_ids : np.ndarray | list + Unit ids to retrieve spike trains for + segment_index : int or None, default: None + The segment index to retrieve spike train from. + For multi-segment objects, it is required + start_time : float or None, default: None + The start time in seconds for spike train extraction + end_time : float or None, default: None + The end time in seconds for spike train extraction + return_times : bool, default: False + If True, returns spike times in seconds instead of frames + use_cache : bool, default: True + If True, then precompute (or use) the to_reordered_spike_vector using + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit ids and values are spike trains (arrays of spike times in seconds) + """ + segment_index = self._check_segment_index(segment_index) + segment = self.segments[segment_index] + + # If sorting has a registered recording, get the frames and get the times from the recording + # Note that this take into account the segment start time of the recording + spike_times = {} + if self.has_recording(): + # Get all the spike times and then slice them + start_frame = None + end_frame = None + spike_train_frames = self.get_unit_spike_trains( + unit_ids=unit_ids, + segment_index=segment_index, + start_frame=start_frame, + end_frame=end_frame, + return_times=False, + use_cache=use_cache, + ) + + for unit_id in unit_ids: + spike_frames = self.sample_index_to_time(spike_train_frames[unit_id], segment_index=segment_index) + + # Filter to return only the spikes within the specified time range + if start_time is not None: + spike_frames = spike_frames[spike_frames >= start_time] + if end_time is not None: + spike_frames = spike_frames[spike_frames <= end_time] + + spike_times[unit_id] = spike_frames + + return spike_times + + # If no recording attached and all back to frame-based conversion + # Get spike train in frames and convert to times using traditional method + start_frame = self.time_to_sample_index(start_time, segment_index=segment_index) if start_time else None + end_frame = self.time_to_sample_index(end_time, segment_index=segment_index) if end_time else None + + spike_frames = self.get_unit_spike_trains( + unit_ids=unit_ids, + segment_index=segment_index, + start_frame=start_frame, + end_frame=end_frame, + return_times=False, + use_cache=use_cache, + ) + for unit_id in unit_ids: + spike_frames = spike_frames[unit_id] + t_start = segment._t_start if segment._t_start is not None else 0 + spike_times[unit_id] = spike_frames / self.get_sampling_frequency() + return t_start + spike_times + def register_recording(self, recording, check_spike_frames: bool = True): """ Register a recording to the sorting. If the sorting and recording both contain @@ -978,7 +1141,7 @@ def to_reordered_spike_vector( s1 = seg_slices[segment_index + 1] slices[unit_index, segment_index, :] = [u0 + s0, u0 + s1] - elif ("sample_index", "unit_index", "segment_index"): + elif lexsort == ("sample_index", "unit_index", "segment_index"): slices = np.zeros((num_segments, num_units, 2), dtype=np.int64) seg_slices = np.searchsorted(ordered_spikes["segment_index"], np.arange(num_segments + 1), side="left") for segment_index in range(self.get_num_segments()): @@ -1083,7 +1246,7 @@ def __init__(self, t_start=None): def get_unit_spike_train( self, - unit_id, + unit_id: int | str, start_frame: int | None = None, end_frame: int | None = None, ) -> np.ndarray: @@ -1091,18 +1254,51 @@ def get_unit_spike_train( Parameters ---------- - unit_id + unit_id : int | str + The unit id for which to get the spike train. start_frame : int, default: None + The start frame for the spike train. If None, it is set to the beginning of the segment. end_frame : int, default: None + The end frame for the spike train. If None, it is set to the end of the segment. + Returns ------- np.ndarray - + The spike train for the given unit id and time interval. """ # must be implemented in subclass raise NotImplementedError + def get_unit_spike_trains( + self, + unit_ids: np.ndarray | list, + start_frame: int | None = None, + end_frame: int | None = None, + ) -> dict[int | str, np.ndarray]: + """Get the spike trains for several units. + Can be implemented in subclass for performance but the default implementation is to call + get_unit_spike_train for each unit_id. + + Parameters + ---------- + unit_ids : numpy.array or list + The unit ids for which to get the spike trains. + start_frame : int, default: None + The start frame for the spike trains. If None, it is set to the beginning of the segment. + end_frame : int, default: None + The end frame for the spike trains. If None, it is set to the end of the segment. + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit_ids and values are the corresponding spike trains. + """ + spike_trains = {} + for unit_id in unit_ids: + spike_trains[unit_id] = self.get_unit_spike_train(unit_id, start_frame=start_frame, end_frame=end_frame) + return spike_trains + class SpikeVectorSortingSegment(BaseSortingSegment): """ diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index 59356db976..4e1e81b81f 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -59,11 +59,15 @@ def _compute_and_cache_spike_vector(self) -> None: all_old_unit_ids=self._parent_sorting.unit_ids, all_new_unit_ids=self._unit_ids, ) - # lexsort by segment_index, sample_index, unit_index - sort_indices = np.lexsort( - (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) - ) - self._cached_spike_vector = spike_vector[sort_indices] + # lexsort by segment_index, sample_index, unit_index, only if needed + # (remapping can change the order of unit indices) + if np.diff(self.ids_to_indixes(self._unit_ids)).min() < 0: + sort_indices = np.lexsort( + (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) + ) + self._cached_spike_vector = spike_vector[sort_indices] + else: + self._cached_spike_vector = spike_vector class UnitsSelectionSortingSegment(BaseSortingSegment): @@ -81,3 +85,12 @@ def get_unit_spike_train( unit_id_parent = self._ids_conversion[unit_id] times = self._parent_segment.get_unit_spike_train(unit_id_parent, start_frame, end_frame) return times + + def get_unit_spike_trains( + self, + unit_ids, + start_frame: int | None = None, + end_frame: int | None = None, + ) -> dict: + unit_ids_parent = [self._ids_conversion[unit_id] for unit_id in unit_ids] + return self._parent_segment.get_unit_spike_trains(unit_ids_parent, start_frame, end_frame) From 85220e5b914534e67298d6b655e6217b0c7a6dac Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 9 Apr 2026 17:58:04 +0200 Subject: [PATCH 03/12] oups --- src/spikeinterface/core/unitsselectionsorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index 4e1e81b81f..dbcf2ee7ce 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -61,7 +61,7 @@ def _compute_and_cache_spike_vector(self) -> None: ) # lexsort by segment_index, sample_index, unit_index, only if needed # (remapping can change the order of unit indices) - if np.diff(self.ids_to_indixes(self._unit_ids)).min() < 0: + if np.diff(self.ids_to_indices(self._unit_ids)).min() < 0: sort_indices = np.lexsort( (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) ) From 0efad8312becfbf9edc8b6d2c3d630524491cee8 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 10 Apr 2026 09:33:32 +0200 Subject: [PATCH 04/12] Fix tests --- src/spikeinterface/core/unitsselectionsorting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index dbcf2ee7ce..d8d2d92afb 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -61,7 +61,7 @@ def _compute_and_cache_spike_vector(self) -> None: ) # lexsort by segment_index, sample_index, unit_index, only if needed # (remapping can change the order of unit indices) - if np.diff(self.ids_to_indices(self._unit_ids)).min() < 0: + if len(self._renamed_unit_ids) > 1 and np.diff(self.ids_to_indices(self._renamed_unit_ids)).min() < 0: sort_indices = np.lexsort( (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) ) From b1911bfd3be20e34775eb506b3173f18c5d69980 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 10 Apr 2026 09:43:00 +0200 Subject: [PATCH 05/12] add tests and fixes --- src/spikeinterface/core/basesorting.py | 10 +++--- .../core/tests/test_basesorting.py | 31 +++++++++++++++++++ 2 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 4e0e1fed52..4aa266f956 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -350,6 +350,7 @@ def get_unit_spike_trains( ) segment_index = self._check_segment_index(segment_index) + segment = self.segments[segment_index] if use_cache: # TODO: speed things up ordered_spike_vector, slices = self.to_reordered_spike_vector( @@ -372,8 +373,9 @@ def get_unit_spike_trains( spike_trains[unit_id] = spike_frames else: spike_trains = segment.get_unit_spike_trains( - unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame, return_times=return_times + unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame ) + return spike_trains def get_unit_spike_trains_in_seconds( self, @@ -453,10 +455,10 @@ def get_unit_spike_trains_in_seconds( use_cache=use_cache, ) for unit_id in unit_ids: - spike_frames = spike_frames[unit_id] + spike_frames_unit = spike_frames[unit_id] t_start = segment._t_start if segment._t_start is not None else 0 - spike_times[unit_id] = spike_frames / self.get_sampling_frequency() - return t_start + spike_times + spike_times[unit_id] = spike_frames_unit / self.get_sampling_frequency() + t_start + return spike_times def register_recording(self, recording, check_spike_frames: bool = True): """ diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 6c06b212b8..3d56d3e4e5 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -310,6 +310,37 @@ def test_select_periods(): np.testing.assert_array_equal(sliced_sorting.to_spike_vector(), sliced_sorting_array.to_spike_vector()) +@pytest.mark.parametrize("use_cache", [False, True]) +def test_get_unit_spike_trains(use_cache): + sampling_frequency = 10_000.0 + duration = 1.0 + num_samples = int(sampling_frequency * duration) + num_units = 10 + sorting = generate_sorting(durations=[duration], sampling_frequency=sampling_frequency, num_units=num_units) + + all_spike_trains = sorting.get_unit_spike_trains(unit_ids=sorting.unit_ids, use_cache=use_cache) + assert isinstance(all_spike_trains, dict) + assert set(all_spike_trains.keys()) == set(sorting.unit_ids) + for unit_id in sorting.unit_ids: + spiketrain = sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id, use_cache=use_cache) + assert np.array_equal(all_spike_trains[unit_id], spiketrain) + + # test with times + spike_trains_times = sorting.get_unit_spike_trains_in_seconds( + unit_ids=sorting.unit_ids, return_times=True, use_cache=use_cache + ) + assert isinstance(spike_trains_times, dict) + assert set(spike_trains_times.keys()) == set(sorting.unit_ids) + for unit_id in sorting.unit_ids: + spiketrain = sorting.get_unit_spike_train( + segment_index=0, unit_id=unit_id, use_cache=use_cache, return_times=True + ) + spiketrain_times = sorting.get_unit_spike_train_in_seconds( + segment_index=0, unit_id=unit_id, use_cache=use_cache + ) + assert np.allclose(spiketrain_times, spiketrain) + + if __name__ == "__main__": import tempfile From 0744705414ee3bf2f5263a32b486a73e0e5aa9b9 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Fri, 10 Apr 2026 14:08:27 -0500 Subject: [PATCH 06/12] Fix bugs in get_unit_spike_trains_in_seconds and segment keying - Drop unused `return_times` parameter from get_unit_spike_trains_in_seconds - Clean up stale/truncated docstrings on get_unit_spike_train_in_seconds, get_unit_spike_trains, and get_unit_spike_trains_in_seconds - Fix UnitsSelectionSortingSegment.get_unit_spike_trains to re-key the returned dict with child unit ids (was returning parent-keyed dict, breaking whenever renamed_unit_ids differ from parent ids) - Fix test_get_unit_spike_trains: drop unused return_times kwarg, remove unused local variable, fix assertion. --- src/spikeinterface/core/basesorting.py | 17 ++++++----------- .../core/tests/test_basesorting.py | 8 ++------ .../core/unitsselectionsorting.py | 3 ++- 3 files changed, 10 insertions(+), 18 deletions(-) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index 4aa266f956..26de57c24d 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -238,10 +238,8 @@ def get_unit_spike_train_in_seconds( The start time in seconds for spike train extraction end_time : float or None, default: None The end time in seconds for spike train extraction - return_times : bool, default: False - If True, returns spike times in seconds instead of frames use_cache : bool, default: True - If True, then precompute (or use) the to_reordered_spike_vector using + If True, precompute (or use) the reordered spike vector cache for fast access. Returns ------- @@ -252,7 +250,7 @@ def get_unit_spike_train_in_seconds( segment = self.segments[segment_index] # If sorting has a registered recording, get the frames and get the times from the recording - # Note that this take into account the segment start time of the recording + # Note that this takes into account the segment start time of the recording if self.has_recording(): # Get all the spike times and then slice them @@ -326,7 +324,7 @@ def get_unit_spike_trains( return_times : bool, default: False If True, returns spike times in seconds instead of frames use_cache : bool, default: True - If True, then precompute (or use) the to_reordered_spike_vector using + If True, precompute (or use) the reordered spike vector cache for fast access. Returns ------- @@ -383,10 +381,9 @@ def get_unit_spike_trains_in_seconds( segment_index: int | None = None, start_time: float | None = None, end_time: float | None = None, - return_times: bool = False, use_cache: bool = True, ) -> dict[int | str, np.ndarray]: - """Return spike trains for multiple units in seconds + """Return spike trains for multiple units in seconds. Parameters ---------- @@ -399,10 +396,8 @@ def get_unit_spike_trains_in_seconds( The start time in seconds for spike train extraction end_time : float or None, default: None The end time in seconds for spike train extraction - return_times : bool, default: False - If True, returns spike times in seconds instead of frames use_cache : bool, default: True - If True, then precompute (or use) the to_reordered_spike_vector using + If True, precompute (or use) the reordered spike vector cache for fast access. Returns ------- @@ -413,7 +408,7 @@ def get_unit_spike_trains_in_seconds( segment = self.segments[segment_index] # If sorting has a registered recording, get the frames and get the times from the recording - # Note that this take into account the segment start time of the recording + # Note that this takes into account the segment start time of the recording spike_times = {} if self.has_recording(): # Get all the spike times and then slice them diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 3d56d3e4e5..247fa671f6 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -314,7 +314,6 @@ def test_select_periods(): def test_get_unit_spike_trains(use_cache): sampling_frequency = 10_000.0 duration = 1.0 - num_samples = int(sampling_frequency * duration) num_units = 10 sorting = generate_sorting(durations=[duration], sampling_frequency=sampling_frequency, num_units=num_units) @@ -327,18 +326,15 @@ def test_get_unit_spike_trains(use_cache): # test with times spike_trains_times = sorting.get_unit_spike_trains_in_seconds( - unit_ids=sorting.unit_ids, return_times=True, use_cache=use_cache + unit_ids=sorting.unit_ids, use_cache=use_cache ) assert isinstance(spike_trains_times, dict) assert set(spike_trains_times.keys()) == set(sorting.unit_ids) for unit_id in sorting.unit_ids: - spiketrain = sorting.get_unit_spike_train( - segment_index=0, unit_id=unit_id, use_cache=use_cache, return_times=True - ) spiketrain_times = sorting.get_unit_spike_train_in_seconds( segment_index=0, unit_id=unit_id, use_cache=use_cache ) - assert np.allclose(spiketrain_times, spiketrain) + assert np.allclose(spike_trains_times[unit_id], spiketrain_times) if __name__ == "__main__": diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index d8d2d92afb..a5ae75873c 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -93,4 +93,5 @@ def get_unit_spike_trains( end_frame: int | None = None, ) -> dict: unit_ids_parent = [self._ids_conversion[unit_id] for unit_id in unit_ids] - return self._parent_segment.get_unit_spike_trains(unit_ids_parent, start_frame, end_frame) + parent_trains = self._parent_segment.get_unit_spike_trains(unit_ids_parent, start_frame, end_frame) + return {child_id: parent_trains[parent_id] for child_id, parent_id in zip(unit_ids, unit_ids_parent)} From c71550b131c521c7bad77f13cc98896f8ac6a7c3 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Fri, 10 Apr 2026 14:19:05 -0500 Subject: [PATCH 07/12] Fix lexsort avoidance check in UnitsSelectionSorting (USS) The previous check `np.diff(self.ids_to_indices(self._renamed_unit_ids)).min() < 0` was never `True`, because `ids_to_indices(self._renamed_unit_ids)` on a USS always returns `[0, 1, ..., k-1]` (since `_main_ids == _renamed_unit_ids`), so the diff was always positive and the lexsort branch was unreachable. Therefore the cached spike vector was wrong whenever two units had co-temporal spikes and the selection reordered them relative to the parent. Replaced with a two-step check that attempt to avoid unneccessary lexsorts: 1. O(k) `_is_order_preserving_selection()` -- Checks if USS `._unit_ids` is in the same relative order as in the parent. When True, the remapped vector is guaranteed sorted (boolean filtering preserves order; the remap only relabels unit_index values). This is the common case via `select_units()` with a boolean mask. 2. O(n) `_is_spike_vector_sorted()` -- Checks if the remapped vector is still sorted by (segment, sample, unit). Catches the case where the selection is not order-preserving but no co-temporal (same exact sample) spikes exist. Falls back to the original O(n log n) lexsort only when both checks fail. --- .../core/unitsselectionsorting.py | 58 +++++++++++++++++-- 1 file changed, 52 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index a5ae75873c..61f8d6e57e 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -3,6 +3,31 @@ from .basesorting import BaseSorting, BaseSortingSegment +def _is_spike_vector_sorted(spike_vector: np.ndarray) -> bool: + """Return True iff the spike vector is sorted by (segment_index, sample_index, unit_index). + + O(n) sequential scan. Used to avoid an O(n log n) lexsort when the vector already + happens to be in canonical order. + """ + n = len(spike_vector) + if n <= 1: + return True + seg = spike_vector["segment_index"] + samp = spike_vector["sample_index"] + unit = spike_vector["unit_index"] + d_seg = np.diff(seg) + if np.any(d_seg < 0): + return False + seg_eq = d_seg == 0 + d_samp = np.diff(samp) + if np.any(d_samp[seg_eq] < 0): + return False + samp_eq = seg_eq & (d_samp == 0) + if np.any(np.diff(unit)[samp_eq] < 0): + return False + return True + + class UnitsSelectionSorting(BaseSorting): """ Class that handles slicing of a Sorting object based on a list of unit_ids. @@ -59,15 +84,36 @@ def _compute_and_cache_spike_vector(self) -> None: all_old_unit_ids=self._parent_sorting.unit_ids, all_new_unit_ids=self._unit_ids, ) - # lexsort by segment_index, sample_index, unit_index, only if needed - # (remapping can change the order of unit indices) - if len(self._renamed_unit_ids) > 1 and np.diff(self.ids_to_indices(self._renamed_unit_ids)).min() < 0: + + # The parent's spike vector is sorted by (segment_index, sample_index, unit_index). + # Boolean filtering by unit preserves that order; the remap only changes unit_index + # values. The result stays sorted iff the selected unit_ids appear in the same + # relative order as in the parent (an O(k) check). If not, the vector may still + # happen to be sorted -- verify with an O(n) scan before falling back to O(n log n) + # lexsort. + if not self._is_order_preserving_selection() and not _is_spike_vector_sorted(spike_vector): sort_indices = np.lexsort( (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) ) - self._cached_spike_vector = spike_vector[sort_indices] - else: - self._cached_spike_vector = spike_vector + spike_vector = spike_vector[sort_indices] + + self._cached_spike_vector = spike_vector + + def _is_order_preserving_selection(self) -> bool: + """Return True iff self._unit_ids appear in the same relative order as in the parent. + + O(k) where k is the number of selected units. When True, the remapped spike vector + is guaranteed to remain sorted by (segment, sample, unit) without re-sorting. + """ + parent_unit_ids = self._parent_sorting.unit_ids + parent_id_to_pos = {uid: i for i, uid in enumerate(parent_unit_ids)} + prev_pos = -1 + for uid in self._unit_ids: + pos = parent_id_to_pos.get(uid) + if pos is None or pos <= prev_pos: + return False + prev_pos = pos + return True class UnitsSelectionSortingSegment(BaseSortingSegment): From 6a825778cefc9ef9a315f6acc0f6c47eeff02eb1 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Fri, 10 Apr 2026 16:54:53 -0500 Subject: [PATCH 08/12] Override _compute_and_cache_spike_vector in Phy/Kilosort extractors `BaseSorting` builds the spike vector with a per-unit boolean scan over spike_clusters, which is (O(N*K)). If we already have the full flat spike time and spike cluster arrays, we can do a lot better by building the spike vector in one shot. (I think O(N log N) from the lexsort, which is also pessimistic, because the lexsort doesn't always need to happen. Under any circumstances I can dream of, K >> log N.) Since Phy/Kilosort segments already load the full flat arrays when the `PhyKilosortSorting` object is created, and keep them around as `._all_spikes` and `._all_clusters`, we can just use those! :) Also populates `_cached_spike_vector_segment_slices` directly, so that `BaseSorting`'s `_get_spike_vector_segment_slices()` lazy recomputation is skipped. --- .../extractors/phykilosortextractors.py | 45 +++++++++++++++++++ 1 file changed, 45 insertions(+) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 0e5dd2694d..61dd8fce80 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -13,6 +13,7 @@ create_sorting_analyzer, SortingAnalyzer, ) +from spikeinterface.core.base import minimum_spike_dtype from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations @@ -224,6 +225,50 @@ def __init__( self.add_sorting_segment(PhySortingSegment(spike_times_clean, spike_clusters_clean)) + def _compute_and_cache_spike_vector(self) -> None: + """Build the spike vector directly from the flat per-segment arrays. + + Since Phy/Kilosort segments already hold the full spike_times and + spike_clusters arrays in memory, we can construct the spike vector + in one shot. + """ + unit_ids = np.asarray(self.unit_ids) + sorter = np.argsort(unit_ids) + sorted_unit_ids = unit_ids[sorter] + + num_seg = self.get_num_segments() + spikes_list = [] + segment_slices = np.zeros((num_seg, 2), dtype="int64") + pos = 0 + + for seg_idx in range(num_seg): + seg = self.segments[seg_idx] + all_spikes = seg._all_spikes + all_clusters = seg._all_clusters + + # Map cluster ids -> unit indices. `spike_clusters_clean` is guaranteed + # to only contain ids present in `self.unit_ids` (filtered in __init__), + # so searchsorted always returns a valid position. + unit_indices = sorter[np.searchsorted(sorted_unit_ids, all_clusters)] + + n = all_spikes.size + segment_slices[seg_idx] = [pos, pos + n] + pos += n + + seg_spikes = np.zeros(n, dtype=minimum_spike_dtype) + seg_spikes["sample_index"] = all_spikes + seg_spikes["unit_index"] = unit_indices + seg_spikes["segment_index"] = seg_idx + spikes_list.append(seg_spikes) + + spikes = np.concatenate(spikes_list) if spikes_list else np.zeros(0, dtype=minimum_spike_dtype) + # Canonical order: (segment_index, sample_index, unit_index). + order = np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"])) + spikes = spikes[order] + + self._cached_spike_vector = spikes + self._cached_spike_vector_segment_slices = segment_slices + class PhySortingSegment(BaseSortingSegment): def __init__(self, all_spikes, all_clusters): From 1d4a3ceb123d89a538cc9de64b1c97d812ba1ae8 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Fri, 10 Apr 2026 18:01:13 -0500 Subject: [PATCH 09/12] Optimize get_unit_spike_trains on PhySortingSegment `BaseSortingSegment.get_unit_spike_trains()` loops over `get_unit_spike_train`, which is O(N*K) because each call is a boolean scan over _all_clusters/_all_spikes. If we know we are going to be getting all the trains, we can do it much faster. And if we can use numba, even faster still. In fact, even if we only want _some_ spike trains, it is still often faster to get all the trains and just discard the ones we don't need, than to get only the trains we need do unit-by-unit (because we only ever store or cache flat arrays of spike times/clusters). Note that **only the use_cache=False path is affected**; the use_cache=True triggers the computation of the spike vector, which I don't think can ever be the most efficient way to get spike trains. --- .../extractors/phykilosortextractors.py | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 61dd8fce80..04a2d47fe3 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -1,3 +1,4 @@ +import importlib.util from pathlib import Path import warnings @@ -19,6 +20,8 @@ from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations from probeinterface import read_prb, Probe +HAVE_NUMBA = importlib.util.find_spec("numba") is not None + class BasePhyKilosortSortingExtractor(BaseSorting): """Base SortingExtractor for Phy and Kilosort output folder. @@ -285,6 +288,104 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): spike_times = self._all_spikes[start:end][self._all_clusters[start:end] == unit_id] return np.atleast_1d(spike_times.copy().squeeze()) + def get_unit_spike_trains( + self, + unit_ids, + start_frame: int | None = None, + end_frame: int | None = None, + ) -> dict: + """Extract spike trains for several units in one pass. + + If you need to get ~20 or more spike trains, this is usually **much** faster + than calling get_unit_spike_train() for each unit. + + Numba-accelerated, if numba is available. Otherwise, falls back to NumPy. + """ + start = 0 if start_frame is None else np.searchsorted(self._all_spikes, start_frame, side="left") + end = ( + len(self._all_spikes) if end_frame is None else np.searchsorted(self._all_spikes, end_frame, side="left") + ) # Exclude end frame + + spikes = self._all_spikes[start:end] + clusters = self._all_clusters[start:end] + + unit_ids_arr = np.asarray(unit_ids) + num_units = len(unit_ids_arr) + if num_units == 0: + return {} + + # Map each spike's cluster id to a destination index in the caller-supplied + # unit_ids order. -1 means "this spike's cluster is not in unit_ids, skip it". + sorter = np.argsort(unit_ids_arr, kind="stable") + sorted_unit_ids = unit_ids_arr[sorter] + idx_in_sorted = np.searchsorted(sorted_unit_ids, clusters, side="left") + idx_clamped = np.minimum(idx_in_sorted, num_units - 1) + matches = (idx_in_sorted < num_units) & (sorted_unit_ids[idx_clamped] == clusters) + dest = np.where(matches, sorter[idx_clamped], -1).astype(np.int64) + + spikes_i64 = np.ascontiguousarray(spikes, dtype=np.int64) + + if HAVE_NUMBA: + offsets, flat_out = _counting_sort_spikes_by_unit(spikes_i64, dest, num_units) + else: + # NumPy fallback: stable argsort by destination index, then split on offsets. + # Stable sort preserves the input order of spikes within each unit group, + # and since _all_spikes is sorted by sample_index, so is each group. + valid = dest >= 0 + valid_spikes = spikes_i64[valid] + valid_dest = dest[valid] + order = np.argsort(valid_dest, kind="stable") + flat_out = valid_spikes[order] + counts = np.bincount(valid_dest, minlength=num_units) + offsets = np.empty(num_units + 1, dtype=np.int64) + offsets[0] = 0 + np.cumsum(counts, out=offsets[1:]) + + return {unit_ids[i]: flat_out[offsets[i] : offsets[i + 1]] for i in range(num_units)} + + +if HAVE_NUMBA: + import numba + + @numba.jit(nopython=True, nogil=True, cache=True) + def _counting_sort_spikes_by_unit(all_spikes, dest_unit_indices, num_units): + """Counting-sort `all_spikes` into per-unit groups. + + Parameters + ---------- + all_spikes : int64 array + Spike sample indices. + dest_unit_indices : int64 array (same length as all_spikes) + Destination unit index for each spike, or -1 to skip. + num_units : int + Number of destination units. + + Returns + ------- + offsets : int64 array of shape (num_units + 1,) + Offsets into `flat_out`; group k is `flat_out[offsets[k]:offsets[k+1]]`. + flat_out : int64 array + Concatenated spike times, grouped by destination unit index. + """ + n = all_spikes.shape[0] + counts = np.zeros(num_units + 1, dtype=np.int64) + for i in range(n): + u = dest_unit_indices[i] + if u >= 0: + counts[u + 1] += 1 + for k in range(1, num_units + 1): + counts[k] += counts[k - 1] + + flat_out = np.empty(counts[num_units], dtype=all_spikes.dtype) + write_pos = counts[:-1].copy() + for i in range(n): + u = dest_unit_indices[i] + if u >= 0: + flat_out[write_pos[u]] = all_spikes[i] + write_pos[u] += 1 + + return counts, flat_out + class PhySortingExtractor(BasePhyKilosortSortingExtractor): """Load Phy format data as a sorting extractor. From 9a139b5dd98443714445e363bc7758bf01afcb3c Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Mon, 13 Apr 2026 17:54:56 -0500 Subject: [PATCH 10/12] Add tests for UnitSelectionSorting & Phy spike vector and train optimizations - Fixed test_compute_and_cache_spike_vector: was comparing an array to itself (to_spike_vector use_cache=False still returns the cached vector). Now explicitly calls the USS override and the BaseSorting implementation, and compares the two. - Added test_uss_get_unit_spike_trains_with_renamed_ids: also not a test of the optimization commits per se, but would have caught a mistake made along the way. Verifies get_unit_spike_trains returns child-keyed dicts (not parent-keyed). - Added test_spike_vector_sorted_after_reorder_with_cotemporal_spikes: verifies the USS spike vector is correctly sorted when the selection reverses unit order and co-temporal spikes exist. - Added test_phy_sorting_segment_get_unit_spike_trains: validates the new fast methods on PhySortingSegment. - Added test_phy_compute_and_cache_spike_vector: verifies the Phy override of _compute_and_cache_spike_vector matches BaseSorting implementation. --- .../core/tests/test_unitsselectionsorting.py | 71 ++++++++++++++++- .../tests/test_phykilosortextractors.py | 78 +++++++++++++++++++ 2 files changed, 146 insertions(+), 3 deletions(-) create mode 100644 src/spikeinterface/extractors/tests/test_phykilosortextractors.py diff --git a/src/spikeinterface/core/tests/test_unitsselectionsorting.py b/src/spikeinterface/core/tests/test_unitsselectionsorting.py index 3aa7bc7577..b5a7965ca5 100644 --- a/src/spikeinterface/core/tests/test_unitsselectionsorting.py +++ b/src/spikeinterface/core/tests/test_unitsselectionsorting.py @@ -3,6 +3,7 @@ from pathlib import Path from spikeinterface.core import UnitsSelectionSorting +from spikeinterface.core.numpyextractors import NumpySorting from spikeinterface.core.generate import generate_sorting @@ -44,12 +45,76 @@ def test_failure_with_non_unique_unit_ids(): def test_compute_and_cache_spike_vector(): + """USS override of _compute_and_cache_spike_vector must produce the same + spike vector as the base class (per-unit) implementation.""" + from spikeinterface.core.basesorting import BaseSorting + sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0) sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"]) - cached_spike_vector = sub_sorting.to_spike_vector(use_cache=True) - computed_spike_vector = sub_sorting.to_spike_vector(use_cache=False) - assert np.all(cached_spike_vector == computed_spike_vector) + + # USS override path + sub_sorting._compute_and_cache_spike_vector() + uss_vector = sub_sorting._cached_spike_vector.copy() + + # Base class (per-unit) path + sub_sorting._cached_spike_vector = None + sub_sorting._cached_spike_vector_segment_slices = None + BaseSorting._compute_and_cache_spike_vector(sub_sorting) + base_vector = sub_sorting._cached_spike_vector + + assert np.array_equal(uss_vector, base_vector) + + +@pytest.mark.parametrize("use_cache", [False, True]) +def test_uss_get_unit_spike_trains_with_renamed_ids(use_cache): + """get_unit_spike_trains on a USS with renamed ids must return dicts with child ids + (as opposed to parent ids) as keys.""" + sorting = generate_sorting(num_units=5, durations=[0.100], sampling_frequency=30000.0, seed=42) + + # Select a subset and rename + sub = UnitsSelectionSorting(sorting, unit_ids=["1", "3", "4"], renamed_unit_ids=["a", "b", "c"]) + renamed_ids = list(sub.unit_ids) + + batch = sub.get_unit_spike_trains(unit_ids=renamed_ids, segment_index=0, use_cache=use_cache) + + assert isinstance(batch, dict) + assert set(batch.keys()) == set(renamed_ids) + + for uid in renamed_ids: + single = sub.get_unit_spike_train(unit_id=uid, segment_index=0, use_cache=use_cache) + assert np.array_equal(batch[uid], single), f"Mismatch for unit {uid}" + + +def test_spike_vector_sorted_after_reorder_with_cotemporal_spikes(): + """USS spike vector must be correctly sorted even when selection reverses unit order + and co-temporal spikes exist (same sample_index, different units).""" + # Build a sorting with guaranteed co-temporal spikes: + # units 0, 1, 2 all fire at sample 100 and 200 + samples = np.array([100, 100, 100, 200, 200, 200, 300, 400], dtype=np.int64) + labels = np.array([0, 1, 2, 0, 1, 2, 0, 1], dtype=np.int64) + sorting = NumpySorting.from_samples_and_labels( + samples_list=[samples], labels_list=[labels], sampling_frequency=30000.0 + ) + + # Reverse the unit order — _is_order_preserving_selection must return False + sub = UnitsSelectionSorting(sorting, unit_ids=[2, 0], renamed_unit_ids=["b", "a"]) + + spike_vector = sub.to_spike_vector() + + # Spike vector must be sorted by (segment_index, sample_index, unit_index) + n = len(spike_vector) + if n > 1: + seg = spike_vector["segment_index"] + samp = spike_vector["sample_index"] + unit = spike_vector["unit_index"] + d_seg = np.diff(seg) + assert np.all(d_seg >= 0), "segment_index not non-decreasing" + seg_eq = d_seg == 0 + d_samp = np.diff(samp) + assert np.all(d_samp[seg_eq] >= 0), "sample_index not non-decreasing within segment" + samp_eq = seg_eq & (d_samp == 0) + assert np.all(np.diff(unit)[samp_eq] >= 0), "unit_index not non-decreasing within same sample" if __name__ == "__main__": diff --git a/src/spikeinterface/extractors/tests/test_phykilosortextractors.py b/src/spikeinterface/extractors/tests/test_phykilosortextractors.py new file mode 100644 index 0000000000..095e79cac1 --- /dev/null +++ b/src/spikeinterface/extractors/tests/test_phykilosortextractors.py @@ -0,0 +1,78 @@ +import pytest +import numpy as np + +from spikeinterface.extractors.phykilosortextractors import PhySortingSegment +import spikeinterface.extractors.phykilosortextractors as phymod + +# Sorted spike times with known cluster assignments. +# 3 units (ids 10, 20, 30), some co-temporal spikes. +ALL_SPIKES = np.array([100, 100, 200, 300, 300, 300, 400, 500], dtype=np.int64) +ALL_CLUSTERS = np.array([10, 20, 30, 10, 20, 30, 10, 20], dtype=np.int64) +UNIT_IDS = [10, 20, 30] + + +@pytest.mark.parametrize("force_numpy_fallback", [False, True]) +def test_phy_sorting_segment_get_unit_spike_trains(monkeypatch, force_numpy_fallback): + """get_unit_spike_trains must match per-unit calls, for both Numba and NumPy paths.""" + if force_numpy_fallback: + monkeypatch.setattr(phymod, "HAVE_NUMBA", False) + + seg = PhySortingSegment(ALL_SPIKES, ALL_CLUSTERS) + + # Full range, all units + batch = seg.get_unit_spike_trains(UNIT_IDS, start_frame=None, end_frame=None) + assert set(batch.keys()) == set(UNIT_IDS) + for uid in UNIT_IDS: + single = seg.get_unit_spike_train(uid, start_frame=None, end_frame=None) + assert np.array_equal(batch[uid], single), f"Mismatch for unit {uid}" + + assert np.array_equal(batch[10], [100, 300, 400]) + assert np.array_equal(batch[20], [100, 300, 500]) + assert np.array_equal(batch[30], [200, 300]) + + # With start_frame / end_frame slicing + batch_sliced = seg.get_unit_spike_trains(UNIT_IDS, start_frame=200, end_frame=400) + assert np.array_equal(batch_sliced[10], [300]) + assert np.array_equal(batch_sliced[20], [300]) + assert np.array_equal(batch_sliced[30], [200, 300]) + + # Subset of unit_ids + batch_subset = seg.get_unit_spike_trains([20], start_frame=None, end_frame=None) + assert list(batch_subset.keys()) == [20] + assert np.array_equal(batch_subset[20], [100, 300, 500]) + + # Empty unit_ids + assert seg.get_unit_spike_trains([], start_frame=None, end_frame=None) == {} + + +def _make_phy_folder(tmp_path): + """Create a minimal Phy output folder for testing.""" + spike_times = np.array([100, 100, 200, 300, 300, 300, 400, 500], dtype=np.int64) + spike_clusters = np.array([10, 20, 30, 10, 20, 30, 10, 20], dtype=np.int64) + + np.save(tmp_path / "spike_times.npy", spike_times) + np.save(tmp_path / "spike_clusters.npy", spike_clusters) + (tmp_path / "params.py").write_text("sample_rate = 30000.0\n") + return tmp_path + + +def test_phy_compute_and_cache_spike_vector(tmp_path): + """Phy override of _compute_and_cache_spike_vector must produce the same + spike vector as the base class (per-unit) implementation.""" + from spikeinterface.core.basesorting import BaseSorting + from spikeinterface.extractors.phykilosortextractors import BasePhyKilosortSortingExtractor + + phy_folder = _make_phy_folder(tmp_path) + sorting = BasePhyKilosortSortingExtractor(phy_folder) + + # Phy override path + sorting._compute_and_cache_spike_vector() + phy_vector = sorting._cached_spike_vector.copy() + + # Base class (per-unit) path + sorting._cached_spike_vector = None + sorting._cached_spike_vector_segment_slices = None + BaseSorting._compute_and_cache_spike_vector(sorting) + base_vector = sorting._cached_spike_vector + + assert np.array_equal(phy_vector, base_vector) From 832f44f7a3e94fb9b11abb7371fe6b5ccf6a6608 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 14 Apr 2026 11:34:16 +0200 Subject: [PATCH 11/12] Move is_spike_vector_sorted to sorting_tools --- src/spikeinterface/core/sorting_tools.py | 25 ++++++++++++++++ .../core/unitsselectionsorting.py | 30 ++----------------- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index cbad29c806..e94ebe2950 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -996,3 +996,28 @@ def remap_unit_indices_in_vector(vector, all_old_unit_ids, all_new_unit_ids, kee new_vector["unit_index"] = mapping[new_vector["unit_index"]] return new_vector, keep_mask_vector + + +def is_spike_vector_sorted(spike_vector: np.ndarray) -> bool: + """Return True iff the spike vector is sorted by (segment_index, sample_index, unit_index). + + O(n) sequential scan. Used to avoid an O(n log n) lexsort when the vector already + happens to be in canonical order. + """ + n = len(spike_vector) + if n <= 1: + return True + seg = spike_vector["segment_index"] + samp = spike_vector["sample_index"] + unit = spike_vector["unit_index"] + d_seg = np.diff(seg) + if np.any(d_seg < 0): + return False + seg_eq = d_seg == 0 + d_samp = np.diff(samp) + if np.any(d_samp[seg_eq] < 0): + return False + samp_eq = seg_eq & (d_samp == 0) + if np.any(np.diff(unit)[samp_eq] < 0): + return False + return True diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index 61f8d6e57e..2e9e4df5d8 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -1,31 +1,7 @@ import numpy as np from .basesorting import BaseSorting, BaseSortingSegment - - -def _is_spike_vector_sorted(spike_vector: np.ndarray) -> bool: - """Return True iff the spike vector is sorted by (segment_index, sample_index, unit_index). - - O(n) sequential scan. Used to avoid an O(n log n) lexsort when the vector already - happens to be in canonical order. - """ - n = len(spike_vector) - if n <= 1: - return True - seg = spike_vector["segment_index"] - samp = spike_vector["sample_index"] - unit = spike_vector["unit_index"] - d_seg = np.diff(seg) - if np.any(d_seg < 0): - return False - seg_eq = d_seg == 0 - d_samp = np.diff(samp) - if np.any(d_samp[seg_eq] < 0): - return False - samp_eq = seg_eq & (d_samp == 0) - if np.any(np.diff(unit)[samp_eq] < 0): - return False - return True +from .sorting_tools import is_spike_vector_sorted class UnitsSelectionSorting(BaseSorting): @@ -91,7 +67,7 @@ def _compute_and_cache_spike_vector(self) -> None: # relative order as in the parent (an O(k) check). If not, the vector may still # happen to be sorted -- verify with an O(n) scan before falling back to O(n log n) # lexsort. - if not self._is_order_preserving_selection() and not _is_spike_vector_sorted(spike_vector): + if not self._is_order_preserving_selection() and not is_spike_vector_sorted(spike_vector): sort_indices = np.lexsort( (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) ) @@ -100,7 +76,7 @@ def _compute_and_cache_spike_vector(self) -> None: self._cached_spike_vector = spike_vector def _is_order_preserving_selection(self) -> bool: - """Return True iff self._unit_ids appear in the same relative order as in the parent. + """Return True if self._unit_ids appear in the same relative order as in the parent. O(k) where k is the number of selected units. When True, the remapped spike vector is guaranteed to remain sorted by (segment, sample, unit) without re-sorting. From 329d220fb1a5fe9a9b1ab346c916065cb1126f16 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 14 Apr 2026 09:37:17 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/core/tests/test_basesorting.py | 4 +--- src/spikeinterface/extractors/phykilosortextractors.py | 10 +++++----- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 247fa671f6..10874a8362 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -325,9 +325,7 @@ def test_get_unit_spike_trains(use_cache): assert np.array_equal(all_spike_trains[unit_id], spiketrain) # test with times - spike_trains_times = sorting.get_unit_spike_trains_in_seconds( - unit_ids=sorting.unit_ids, use_cache=use_cache - ) + spike_trains_times = sorting.get_unit_spike_trains_in_seconds(unit_ids=sorting.unit_ids, use_cache=use_cache) assert isinstance(spike_trains_times, dict) assert set(spike_trains_times.keys()) == set(sorting.unit_ids) for unit_id in sorting.unit_ids: diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 04a2d47fe3..1530072288 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -231,9 +231,9 @@ def __init__( def _compute_and_cache_spike_vector(self) -> None: """Build the spike vector directly from the flat per-segment arrays. - Since Phy/Kilosort segments already hold the full spike_times and + Since Phy/Kilosort segments already hold the full spike_times and spike_clusters arrays in memory, we can construct the spike vector - in one shot. + in one shot. """ unit_ids = np.asarray(self.unit_ids) sorter = np.argsort(unit_ids) @@ -295,11 +295,11 @@ def get_unit_spike_trains( end_frame: int | None = None, ) -> dict: """Extract spike trains for several units in one pass. - - If you need to get ~20 or more spike trains, this is usually **much** faster + + If you need to get ~20 or more spike trains, this is usually **much** faster than calling get_unit_spike_train() for each unit. - Numba-accelerated, if numba is available. Otherwise, falls back to NumPy. + Numba-accelerated, if numba is available. Otherwise, falls back to NumPy. """ start = 0 if start_frame is None else np.searchsorted(self._all_spikes, start_frame, side="left") end = (