diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index cb68f3d455..26de57c24d 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -180,6 +180,7 @@ def get_unit_spike_train( segment_index=segment_index, start_time=start_time, end_time=end_time, + use_cache=use_cache, ) segment_index = self._check_segment_index(segment_index) @@ -212,6 +213,7 @@ def get_unit_spike_train_in_seconds( segment_index: int | None = None, start_time: float | None = None, end_time: float | None = None, + use_cache: bool = True, ) -> np.ndarray: """ Get spike train for a unit in seconds. @@ -236,6 +238,8 @@ def get_unit_spike_train_in_seconds( The start time in seconds for spike train extraction end_time : float or None, default: None The end time in seconds for spike train extraction + use_cache : bool, default: True + If True, precompute (or use) the reordered spike vector cache for fast access. Returns ------- @@ -246,7 +250,7 @@ def get_unit_spike_train_in_seconds( segment = self.segments[segment_index] # If sorting has a registered recording, get the frames and get the times from the recording - # Note that this take into account the segment start time of the recording + # Note that this takes into account the segment start time of the recording if self.has_recording(): # Get all the spike times and then slice them @@ -258,7 +262,7 @@ def get_unit_spike_train_in_seconds( start_frame=start_frame, end_frame=end_frame, return_times=False, - use_cache=True, + use_cache=use_cache, ) spike_times = self.sample_index_to_time(spike_frames, segment_index=segment_index) @@ -288,13 +292,169 @@ def get_unit_spike_train_in_seconds( start_frame=start_frame, end_frame=end_frame, return_times=False, - use_cache=True, + use_cache=use_cache, ) t_start = segment._t_start if segment._t_start is not None else 0 spike_times = spike_frames / self.get_sampling_frequency() return t_start + spike_times + def 
get_unit_spike_trains( + self, + unit_ids: np.ndarray | list, + segment_index: int | None = None, + start_frame: int | None = None, + end_frame: int | None = None, + return_times: bool = False, + use_cache: bool = True, + ) -> dict[int | str, np.ndarray]: + """Return spike trains for multiple units. + + Parameters + ---------- + unit_ids : np.ndarray | list + Unit ids to retrieve spike trains for + segment_index : int or None, default: None + The segment index to retrieve spike train from. + For multi-segment objects, it is required + start_frame : int or None, default: None + The start frame for spike train extraction + end_frame : int or None, default: None + The end frame for spike train extraction + return_times : bool, default: False + If True, returns spike times in seconds instead of frames + use_cache : bool, default: True + If True, precompute (or use) the reordered spike vector cache for fast access. + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit ids and values are spike trains (arrays of spike times or frames) + """ + if return_times: + start_time = ( + self.sample_index_to_time(start_frame, segment_index=segment_index) if start_frame is not None else None + ) + end_time = ( + self.sample_index_to_time(end_frame, segment_index=segment_index) if end_frame is not None else None + ) + + return self.get_unit_spike_trains_in_seconds( + unit_ids=unit_ids, + segment_index=segment_index, + start_time=start_time, + end_time=end_time, + use_cache=use_cache, + ) + + segment_index = self._check_segment_index(segment_index) + segment = self.segments[segment_index] + if use_cache: + # TODO: speed things up + ordered_spike_vector, slices = self.to_reordered_spike_vector( + lexsort=("sample_index", "segment_index", "unit_index"), + return_order=False, + return_slices=True, + ) + unit_indices = self.ids_to_indices(unit_ids) + spike_trains = {} + for unit_index, unit_id in zip(unit_indices, unit_ids): + sl0, sl1 = slices[unit_index, 
segment_index, :] + spikes = ordered_spike_vector[sl0:sl1] + spike_frames = spikes["sample_index"] + if start_frame is not None: + start = np.searchsorted(spike_frames, start_frame) + spike_frames = spike_frames[start:] + if end_frame is not None: + end = np.searchsorted(spike_frames, end_frame) + spike_frames = spike_frames[:end] + spike_trains[unit_id] = spike_frames + else: + spike_trains = segment.get_unit_spike_trains( + unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame + ) + return spike_trains + + def get_unit_spike_trains_in_seconds( + self, + unit_ids: np.ndarray | list, + segment_index: int | None = None, + start_time: float | None = None, + end_time: float | None = None, + use_cache: bool = True, + ) -> dict[int | str, np.ndarray]: + """Return spike trains for multiple units in seconds. + + Parameters + ---------- + unit_ids : np.ndarray | list + Unit ids to retrieve spike trains for + segment_index : int or None, default: None + The segment index to retrieve spike train from. + For multi-segment objects, it is required + start_time : float or None, default: None + The start time in seconds for spike train extraction + end_time : float or None, default: None + The end time in seconds for spike train extraction + use_cache : bool, default: True + If True, precompute (or use) the reordered spike vector cache for fast access. 
+
+        Returns
+        -------
+        dict[int | str, np.ndarray]
+            A dictionary where keys are unit ids and values are spike trains (arrays of spike times in seconds)
+        """
+        segment_index = self._check_segment_index(segment_index)
+        segment = self.segments[segment_index]
+
+        # If sorting has a registered recording, get the frames and get the times from the recording
+        # Note that this takes into account the segment start time of the recording
+        spike_times = {}
+        if self.has_recording():
+            # Get all the spike times and then slice them
+            start_frame = None
+            end_frame = None
+            spike_train_frames = self.get_unit_spike_trains(
+                unit_ids=unit_ids,
+                segment_index=segment_index,
+                start_frame=start_frame,
+                end_frame=end_frame,
+                return_times=False,
+                use_cache=use_cache,
+            )
+
+            for unit_id in unit_ids:
+                spike_times_unit = self.sample_index_to_time(spike_train_frames[unit_id], segment_index=segment_index)
+
+                # Filter to return only the spikes within the specified time range
+                if start_time is not None:
+                    spike_times_unit = spike_times_unit[spike_times_unit >= start_time]
+                if end_time is not None:
+                    spike_times_unit = spike_times_unit[spike_times_unit <= end_time]
+
+                spike_times[unit_id] = spike_times_unit
+
+            return spike_times
+
+        # If no recording attached, fall back to frame-based conversion
+        # Get spike train in frames and convert to times using traditional method
+        # NOTE: explicit `is not None` checks so that a valid 0.0 start/end time is not treated as "unset"
+        start_frame = self.time_to_sample_index(start_time, segment_index=segment_index) if start_time is not None else None
+        end_frame = self.time_to_sample_index(end_time, segment_index=segment_index) if end_time is not None else None
+
+        spike_frames = self.get_unit_spike_trains(
+            unit_ids=unit_ids,
+            segment_index=segment_index,
+            start_frame=start_frame,
+            end_frame=end_frame,
+            return_times=False,
+            use_cache=use_cache,
+        )
+        t_start = segment._t_start if segment._t_start is not None else 0
+        for unit_id in unit_ids:
+            spike_frames_unit = spike_frames[unit_id]
+            spike_times[unit_id] = spike_frames_unit / self.get_sampling_frequency() + t_start
+        return 
spike_times + def register_recording(self, recording, check_spike_frames: bool = True): """ Register a recording to the sorting. If the sorting and recording both contain @@ -978,7 +1138,7 @@ def to_reordered_spike_vector( s1 = seg_slices[segment_index + 1] slices[unit_index, segment_index, :] = [u0 + s0, u0 + s1] - elif ("sample_index", "unit_index", "segment_index"): + elif lexsort == ("sample_index", "unit_index", "segment_index"): slices = np.zeros((num_segments, num_units, 2), dtype=np.int64) seg_slices = np.searchsorted(ordered_spikes["segment_index"], np.arange(num_segments + 1), side="left") for segment_index in range(self.get_num_segments()): @@ -1083,7 +1243,7 @@ def __init__(self, t_start=None): def get_unit_spike_train( self, - unit_id, + unit_id: int | str, start_frame: int | None = None, end_frame: int | None = None, ) -> np.ndarray: @@ -1091,18 +1251,51 @@ def get_unit_spike_train( Parameters ---------- - unit_id + unit_id : int | str + The unit id for which to get the spike train. start_frame : int, default: None + The start frame for the spike train. If None, it is set to the beginning of the segment. end_frame : int, default: None + The end frame for the spike train. If None, it is set to the end of the segment. + Returns ------- np.ndarray - + The spike train for the given unit id and time interval. """ # must be implemented in subclass raise NotImplementedError + def get_unit_spike_trains( + self, + unit_ids: np.ndarray | list, + start_frame: int | None = None, + end_frame: int | None = None, + ) -> dict[int | str, np.ndarray]: + """Get the spike trains for several units. + Can be implemented in subclass for performance but the default implementation is to call + get_unit_spike_train for each unit_id. + + Parameters + ---------- + unit_ids : numpy.array or list + The unit ids for which to get the spike trains. + start_frame : int, default: None + The start frame for the spike trains. If None, it is set to the beginning of the segment. 
+ end_frame : int, default: None + The end frame for the spike trains. If None, it is set to the end of the segment. + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit_ids and values are the corresponding spike trains. + """ + spike_trains = {} + for unit_id in unit_ids: + spike_trains[unit_id] = self.get_unit_spike_train(unit_id, start_frame=start_frame, end_frame=end_frame) + return spike_trains + class SpikeVectorSortingSegment(BaseSortingSegment): """ diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index cbad29c806..e94ebe2950 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -996,3 +996,28 @@ def remap_unit_indices_in_vector(vector, all_old_unit_ids, all_new_unit_ids, kee new_vector["unit_index"] = mapping[new_vector["unit_index"]] return new_vector, keep_mask_vector + + +def is_spike_vector_sorted(spike_vector: np.ndarray) -> bool: + """Return True iff the spike vector is sorted by (segment_index, sample_index, unit_index). + + O(n) sequential scan. Used to avoid an O(n log n) lexsort when the vector already + happens to be in canonical order. 
+ """ + n = len(spike_vector) + if n <= 1: + return True + seg = spike_vector["segment_index"] + samp = spike_vector["sample_index"] + unit = spike_vector["unit_index"] + d_seg = np.diff(seg) + if np.any(d_seg < 0): + return False + seg_eq = d_seg == 0 + d_samp = np.diff(samp) + if np.any(d_samp[seg_eq] < 0): + return False + samp_eq = seg_eq & (d_samp == 0) + if np.any(np.diff(unit)[samp_eq] < 0): + return False + return True diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 6c06b212b8..10874a8362 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -310,6 +310,31 @@ def test_select_periods(): np.testing.assert_array_equal(sliced_sorting.to_spike_vector(), sliced_sorting_array.to_spike_vector()) +@pytest.mark.parametrize("use_cache", [False, True]) +def test_get_unit_spike_trains(use_cache): + sampling_frequency = 10_000.0 + duration = 1.0 + num_units = 10 + sorting = generate_sorting(durations=[duration], sampling_frequency=sampling_frequency, num_units=num_units) + + all_spike_trains = sorting.get_unit_spike_trains(unit_ids=sorting.unit_ids, use_cache=use_cache) + assert isinstance(all_spike_trains, dict) + assert set(all_spike_trains.keys()) == set(sorting.unit_ids) + for unit_id in sorting.unit_ids: + spiketrain = sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id, use_cache=use_cache) + assert np.array_equal(all_spike_trains[unit_id], spiketrain) + + # test with times + spike_trains_times = sorting.get_unit_spike_trains_in_seconds(unit_ids=sorting.unit_ids, use_cache=use_cache) + assert isinstance(spike_trains_times, dict) + assert set(spike_trains_times.keys()) == set(sorting.unit_ids) + for unit_id in sorting.unit_ids: + spiketrain_times = sorting.get_unit_spike_train_in_seconds( + segment_index=0, unit_id=unit_id, use_cache=use_cache + ) + assert np.allclose(spike_trains_times[unit_id], spiketrain_times) + + 
if __name__ == "__main__": import tempfile diff --git a/src/spikeinterface/core/tests/test_unitsselectionsorting.py b/src/spikeinterface/core/tests/test_unitsselectionsorting.py index 3aa7bc7577..b5a7965ca5 100644 --- a/src/spikeinterface/core/tests/test_unitsselectionsorting.py +++ b/src/spikeinterface/core/tests/test_unitsselectionsorting.py @@ -3,6 +3,7 @@ from pathlib import Path from spikeinterface.core import UnitsSelectionSorting +from spikeinterface.core.numpyextractors import NumpySorting from spikeinterface.core.generate import generate_sorting @@ -44,12 +45,76 @@ def test_failure_with_non_unique_unit_ids(): def test_compute_and_cache_spike_vector(): + """USS override of _compute_and_cache_spike_vector must produce the same + spike vector as the base class (per-unit) implementation.""" + from spikeinterface.core.basesorting import BaseSorting + sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0) sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"]) - cached_spike_vector = sub_sorting.to_spike_vector(use_cache=True) - computed_spike_vector = sub_sorting.to_spike_vector(use_cache=False) - assert np.all(cached_spike_vector == computed_spike_vector) + + # USS override path + sub_sorting._compute_and_cache_spike_vector() + uss_vector = sub_sorting._cached_spike_vector.copy() + + # Base class (per-unit) path + sub_sorting._cached_spike_vector = None + sub_sorting._cached_spike_vector_segment_slices = None + BaseSorting._compute_and_cache_spike_vector(sub_sorting) + base_vector = sub_sorting._cached_spike_vector + + assert np.array_equal(uss_vector, base_vector) + + +@pytest.mark.parametrize("use_cache", [False, True]) +def test_uss_get_unit_spike_trains_with_renamed_ids(use_cache): + """get_unit_spike_trains on a USS with renamed ids must return dicts with child ids + (as opposed to parent ids) as keys.""" + sorting = generate_sorting(num_units=5, durations=[0.100], 
sampling_frequency=30000.0, seed=42) + + # Select a subset and rename + sub = UnitsSelectionSorting(sorting, unit_ids=["1", "3", "4"], renamed_unit_ids=["a", "b", "c"]) + renamed_ids = list(sub.unit_ids) + + batch = sub.get_unit_spike_trains(unit_ids=renamed_ids, segment_index=0, use_cache=use_cache) + + assert isinstance(batch, dict) + assert set(batch.keys()) == set(renamed_ids) + + for uid in renamed_ids: + single = sub.get_unit_spike_train(unit_id=uid, segment_index=0, use_cache=use_cache) + assert np.array_equal(batch[uid], single), f"Mismatch for unit {uid}" + + +def test_spike_vector_sorted_after_reorder_with_cotemporal_spikes(): + """USS spike vector must be correctly sorted even when selection reverses unit order + and co-temporal spikes exist (same sample_index, different units).""" + # Build a sorting with guaranteed co-temporal spikes: + # units 0, 1, 2 all fire at sample 100 and 200 + samples = np.array([100, 100, 100, 200, 200, 200, 300, 400], dtype=np.int64) + labels = np.array([0, 1, 2, 0, 1, 2, 0, 1], dtype=np.int64) + sorting = NumpySorting.from_samples_and_labels( + samples_list=[samples], labels_list=[labels], sampling_frequency=30000.0 + ) + + # Reverse the unit order — _is_order_preserving_selection must return False + sub = UnitsSelectionSorting(sorting, unit_ids=[2, 0], renamed_unit_ids=["b", "a"]) + + spike_vector = sub.to_spike_vector() + + # Spike vector must be sorted by (segment_index, sample_index, unit_index) + n = len(spike_vector) + if n > 1: + seg = spike_vector["segment_index"] + samp = spike_vector["sample_index"] + unit = spike_vector["unit_index"] + d_seg = np.diff(seg) + assert np.all(d_seg >= 0), "segment_index not non-decreasing" + seg_eq = d_seg == 0 + d_samp = np.diff(samp) + assert np.all(d_samp[seg_eq] >= 0), "sample_index not non-decreasing within segment" + samp_eq = seg_eq & (d_samp == 0) + assert np.all(np.diff(unit)[samp_eq] >= 0), "unit_index not non-decreasing within same sample" if __name__ == "__main__": diff 
--git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py index 59356db976..2e9e4df5d8 100644 --- a/src/spikeinterface/core/unitsselectionsorting.py +++ b/src/spikeinterface/core/unitsselectionsorting.py @@ -1,6 +1,7 @@ import numpy as np from .basesorting import BaseSorting, BaseSortingSegment +from .sorting_tools import is_spike_vector_sorted class UnitsSelectionSorting(BaseSorting): @@ -59,11 +60,36 @@ def _compute_and_cache_spike_vector(self) -> None: all_old_unit_ids=self._parent_sorting.unit_ids, all_new_unit_ids=self._unit_ids, ) - # lexsort by segment_index, sample_index, unit_index - sort_indices = np.lexsort( - (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) - ) - self._cached_spike_vector = spike_vector[sort_indices] + + # The parent's spike vector is sorted by (segment_index, sample_index, unit_index). + # Boolean filtering by unit preserves that order; the remap only changes unit_index + # values. The result stays sorted iff the selected unit_ids appear in the same + # relative order as in the parent (an O(k) check). If not, the vector may still + # happen to be sorted -- verify with an O(n) scan before falling back to O(n log n) + # lexsort. + if not self._is_order_preserving_selection() and not is_spike_vector_sorted(spike_vector): + sort_indices = np.lexsort( + (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"]) + ) + spike_vector = spike_vector[sort_indices] + + self._cached_spike_vector = spike_vector + + def _is_order_preserving_selection(self) -> bool: + """Return True if self._unit_ids appear in the same relative order as in the parent. + + O(k) where k is the number of selected units. When True, the remapped spike vector + is guaranteed to remain sorted by (segment, sample, unit) without re-sorting. 
+ """ + parent_unit_ids = self._parent_sorting.unit_ids + parent_id_to_pos = {uid: i for i, uid in enumerate(parent_unit_ids)} + prev_pos = -1 + for uid in self._unit_ids: + pos = parent_id_to_pos.get(uid) + if pos is None or pos <= prev_pos: + return False + prev_pos = pos + return True class UnitsSelectionSortingSegment(BaseSortingSegment): @@ -81,3 +107,13 @@ def get_unit_spike_train( unit_id_parent = self._ids_conversion[unit_id] times = self._parent_segment.get_unit_spike_train(unit_id_parent, start_frame, end_frame) return times + + def get_unit_spike_trains( + self, + unit_ids, + start_frame: int | None = None, + end_frame: int | None = None, + ) -> dict: + unit_ids_parent = [self._ids_conversion[unit_id] for unit_id in unit_ids] + parent_trains = self._parent_segment.get_unit_spike_trains(unit_ids_parent, start_frame, end_frame) + return {child_id: parent_trains[parent_id] for child_id, parent_id in zip(unit_ids, unit_ids_parent)} diff --git a/src/spikeinterface/extractors/phykilosortextractors.py b/src/spikeinterface/extractors/phykilosortextractors.py index 0e5dd2694d..1530072288 100644 --- a/src/spikeinterface/extractors/phykilosortextractors.py +++ b/src/spikeinterface/extractors/phykilosortextractors.py @@ -1,3 +1,4 @@ +import importlib.util from pathlib import Path import warnings @@ -13,11 +14,14 @@ create_sorting_analyzer, SortingAnalyzer, ) +from spikeinterface.core.base import minimum_spike_dtype from spikeinterface.core.core_tools import define_function_from_class from spikeinterface.postprocessing import ComputeSpikeAmplitudes, ComputeSpikeLocations from probeinterface import read_prb, Probe +HAVE_NUMBA = importlib.util.find_spec("numba") is not None + class BasePhyKilosortSortingExtractor(BaseSorting): """Base SortingExtractor for Phy and Kilosort output folder. 
@@ -224,6 +228,50 @@ def __init__( self.add_sorting_segment(PhySortingSegment(spike_times_clean, spike_clusters_clean)) + def _compute_and_cache_spike_vector(self) -> None: + """Build the spike vector directly from the flat per-segment arrays. + + Since Phy/Kilosort segments already hold the full spike_times and + spike_clusters arrays in memory, we can construct the spike vector + in one shot. + """ + unit_ids = np.asarray(self.unit_ids) + sorter = np.argsort(unit_ids) + sorted_unit_ids = unit_ids[sorter] + + num_seg = self.get_num_segments() + spikes_list = [] + segment_slices = np.zeros((num_seg, 2), dtype="int64") + pos = 0 + + for seg_idx in range(num_seg): + seg = self.segments[seg_idx] + all_spikes = seg._all_spikes + all_clusters = seg._all_clusters + + # Map cluster ids -> unit indices. `spike_clusters_clean` is guaranteed + # to only contain ids present in `self.unit_ids` (filtered in __init__), + # so searchsorted always returns a valid position. + unit_indices = sorter[np.searchsorted(sorted_unit_ids, all_clusters)] + + n = all_spikes.size + segment_slices[seg_idx] = [pos, pos + n] + pos += n + + seg_spikes = np.zeros(n, dtype=minimum_spike_dtype) + seg_spikes["sample_index"] = all_spikes + seg_spikes["unit_index"] = unit_indices + seg_spikes["segment_index"] = seg_idx + spikes_list.append(seg_spikes) + + spikes = np.concatenate(spikes_list) if spikes_list else np.zeros(0, dtype=minimum_spike_dtype) + # Canonical order: (segment_index, sample_index, unit_index). 
+ order = np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"])) + spikes = spikes[order] + + self._cached_spike_vector = spikes + self._cached_spike_vector_segment_slices = segment_slices + class PhySortingSegment(BaseSortingSegment): def __init__(self, all_spikes, all_clusters): @@ -240,6 +288,104 @@ def get_unit_spike_train(self, unit_id, start_frame, end_frame): spike_times = self._all_spikes[start:end][self._all_clusters[start:end] == unit_id] return np.atleast_1d(spike_times.copy().squeeze()) + def get_unit_spike_trains( + self, + unit_ids, + start_frame: int | None = None, + end_frame: int | None = None, + ) -> dict: + """Extract spike trains for several units in one pass. + + If you need to get ~20 or more spike trains, this is usually **much** faster + than calling get_unit_spike_train() for each unit. + + Numba-accelerated, if numba is available. Otherwise, falls back to NumPy. + """ + start = 0 if start_frame is None else np.searchsorted(self._all_spikes, start_frame, side="left") + end = ( + len(self._all_spikes) if end_frame is None else np.searchsorted(self._all_spikes, end_frame, side="left") + ) # Exclude end frame + + spikes = self._all_spikes[start:end] + clusters = self._all_clusters[start:end] + + unit_ids_arr = np.asarray(unit_ids) + num_units = len(unit_ids_arr) + if num_units == 0: + return {} + + # Map each spike's cluster id to a destination index in the caller-supplied + # unit_ids order. -1 means "this spike's cluster is not in unit_ids, skip it". 
+ sorter = np.argsort(unit_ids_arr, kind="stable") + sorted_unit_ids = unit_ids_arr[sorter] + idx_in_sorted = np.searchsorted(sorted_unit_ids, clusters, side="left") + idx_clamped = np.minimum(idx_in_sorted, num_units - 1) + matches = (idx_in_sorted < num_units) & (sorted_unit_ids[idx_clamped] == clusters) + dest = np.where(matches, sorter[idx_clamped], -1).astype(np.int64) + + spikes_i64 = np.ascontiguousarray(spikes, dtype=np.int64) + + if HAVE_NUMBA: + offsets, flat_out = _counting_sort_spikes_by_unit(spikes_i64, dest, num_units) + else: + # NumPy fallback: stable argsort by destination index, then split on offsets. + # Stable sort preserves the input order of spikes within each unit group, + # and since _all_spikes is sorted by sample_index, so is each group. + valid = dest >= 0 + valid_spikes = spikes_i64[valid] + valid_dest = dest[valid] + order = np.argsort(valid_dest, kind="stable") + flat_out = valid_spikes[order] + counts = np.bincount(valid_dest, minlength=num_units) + offsets = np.empty(num_units + 1, dtype=np.int64) + offsets[0] = 0 + np.cumsum(counts, out=offsets[1:]) + + return {unit_ids[i]: flat_out[offsets[i] : offsets[i + 1]] for i in range(num_units)} + + +if HAVE_NUMBA: + import numba + + @numba.jit(nopython=True, nogil=True, cache=True) + def _counting_sort_spikes_by_unit(all_spikes, dest_unit_indices, num_units): + """Counting-sort `all_spikes` into per-unit groups. + + Parameters + ---------- + all_spikes : int64 array + Spike sample indices. + dest_unit_indices : int64 array (same length as all_spikes) + Destination unit index for each spike, or -1 to skip. + num_units : int + Number of destination units. + + Returns + ------- + offsets : int64 array of shape (num_units + 1,) + Offsets into `flat_out`; group k is `flat_out[offsets[k]:offsets[k+1]]`. + flat_out : int64 array + Concatenated spike times, grouped by destination unit index. 
+ """ + n = all_spikes.shape[0] + counts = np.zeros(num_units + 1, dtype=np.int64) + for i in range(n): + u = dest_unit_indices[i] + if u >= 0: + counts[u + 1] += 1 + for k in range(1, num_units + 1): + counts[k] += counts[k - 1] + + flat_out = np.empty(counts[num_units], dtype=all_spikes.dtype) + write_pos = counts[:-1].copy() + for i in range(n): + u = dest_unit_indices[i] + if u >= 0: + flat_out[write_pos[u]] = all_spikes[i] + write_pos[u] += 1 + + return counts, flat_out + class PhySortingExtractor(BasePhyKilosortSortingExtractor): """Load Phy format data as a sorting extractor. diff --git a/src/spikeinterface/extractors/tests/test_phykilosortextractors.py b/src/spikeinterface/extractors/tests/test_phykilosortextractors.py new file mode 100644 index 0000000000..095e79cac1 --- /dev/null +++ b/src/spikeinterface/extractors/tests/test_phykilosortextractors.py @@ -0,0 +1,78 @@ +import pytest +import numpy as np + +from spikeinterface.extractors.phykilosortextractors import PhySortingSegment +import spikeinterface.extractors.phykilosortextractors as phymod + +# Sorted spike times with known cluster assignments. +# 3 units (ids 10, 20, 30), some co-temporal spikes. 
+ALL_SPIKES = np.array([100, 100, 200, 300, 300, 300, 400, 500], dtype=np.int64) +ALL_CLUSTERS = np.array([10, 20, 30, 10, 20, 30, 10, 20], dtype=np.int64) +UNIT_IDS = [10, 20, 30] + + +@pytest.mark.parametrize("force_numpy_fallback", [False, True]) +def test_phy_sorting_segment_get_unit_spike_trains(monkeypatch, force_numpy_fallback): + """get_unit_spike_trains must match per-unit calls, for both Numba and NumPy paths.""" + if force_numpy_fallback: + monkeypatch.setattr(phymod, "HAVE_NUMBA", False) + + seg = PhySortingSegment(ALL_SPIKES, ALL_CLUSTERS) + + # Full range, all units + batch = seg.get_unit_spike_trains(UNIT_IDS, start_frame=None, end_frame=None) + assert set(batch.keys()) == set(UNIT_IDS) + for uid in UNIT_IDS: + single = seg.get_unit_spike_train(uid, start_frame=None, end_frame=None) + assert np.array_equal(batch[uid], single), f"Mismatch for unit {uid}" + + assert np.array_equal(batch[10], [100, 300, 400]) + assert np.array_equal(batch[20], [100, 300, 500]) + assert np.array_equal(batch[30], [200, 300]) + + # With start_frame / end_frame slicing + batch_sliced = seg.get_unit_spike_trains(UNIT_IDS, start_frame=200, end_frame=400) + assert np.array_equal(batch_sliced[10], [300]) + assert np.array_equal(batch_sliced[20], [300]) + assert np.array_equal(batch_sliced[30], [200, 300]) + + # Subset of unit_ids + batch_subset = seg.get_unit_spike_trains([20], start_frame=None, end_frame=None) + assert list(batch_subset.keys()) == [20] + assert np.array_equal(batch_subset[20], [100, 300, 500]) + + # Empty unit_ids + assert seg.get_unit_spike_trains([], start_frame=None, end_frame=None) == {} + + +def _make_phy_folder(tmp_path): + """Create a minimal Phy output folder for testing.""" + spike_times = np.array([100, 100, 200, 300, 300, 300, 400, 500], dtype=np.int64) + spike_clusters = np.array([10, 20, 30, 10, 20, 30, 10, 20], dtype=np.int64) + + np.save(tmp_path / "spike_times.npy", spike_times) + np.save(tmp_path / "spike_clusters.npy", spike_clusters) + 
(tmp_path / "params.py").write_text("sample_rate = 30000.0\n") + return tmp_path + + +def test_phy_compute_and_cache_spike_vector(tmp_path): + """Phy override of _compute_and_cache_spike_vector must produce the same + spike vector as the base class (per-unit) implementation.""" + from spikeinterface.core.basesorting import BaseSorting + from spikeinterface.extractors.phykilosortextractors import BasePhyKilosortSortingExtractor + + phy_folder = _make_phy_folder(tmp_path) + sorting = BasePhyKilosortSortingExtractor(phy_folder) + + # Phy override path + sorting._compute_and_cache_spike_vector() + phy_vector = sorting._cached_spike_vector.copy() + + # Base class (per-unit) path + sorting._cached_spike_vector = None + sorting._cached_spike_vector_segment_slices = None + BaseSorting._compute_and_cache_spike_vector(sorting) + base_vector = sorting._cached_spike_vector + + assert np.array_equal(phy_vector, base_vector)