Implement get_unit_spike_trains and performance improvements #4502

alejoe91 wants to merge 24 commits into SpikeInterface:main
- Drop unused `return_times` parameter from `get_unit_spike_trains_in_seconds`
- Clean up stale/truncated docstrings on `get_unit_spike_train_in_seconds`, `get_unit_spike_trains`, and `get_unit_spike_trains_in_seconds`
- Fix `UnitsSelectionSortingSegment.get_unit_spike_trains` to re-key the returned dict with child unit ids (it was returning a parent-keyed dict, breaking whenever `renamed_unit_ids` differ from parent ids)
- Fix `test_get_unit_spike_trains`: drop the unused `return_times` kwarg, remove an unused local variable, fix an assertion.
The previous check `np.diff(self.ids_to_indices(self._renamed_unit_ids)).min() < 0`
was never `True`, because `ids_to_indices(self._renamed_unit_ids)` on a USS
always returns `[0, 1, ..., k-1]` (since `_main_ids == _renamed_unit_ids`), so the
diff was always positive and the lexsort branch was unreachable. Therefore the
cached spike vector was wrong whenever two units had co-temporal spikes and the
selection reordered them relative to the parent.
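To see why the old check was vacuous, here is a minimal illustration (the unit ids are hypothetical; `ids_to_indices` is mimicked with a flatnonzero lookup rather than the actual SpikeInterface method): on a UnitsSelectionSorting, `_main_ids == _renamed_unit_ids`, so looking up the renamed ids in the sorting's own id list always yields the identity permutation.

```python
import numpy as np

# Hypothetical child unit ids on a UnitsSelectionSorting (USS)
renamed_unit_ids = np.array(["c", "a", "b"])
# On a USS, _main_ids == _renamed_unit_ids, so ids_to_indices is keyed
# by the USS's own ids, not by the parent's ordering.
main_ids = renamed_unit_ids
indices = np.array([np.flatnonzero(main_ids == uid)[0] for uid in renamed_unit_ids])
print(indices)                 # [0 1 2] -- identity, regardless of parent order
print(np.diff(indices).min())  # 1, so the `< 0` test can never fire
```

Because the diff is always strictly positive, the `min() < 0` guard never triggers and the lexsort branch is dead code.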
Replaced with a two-step check that attempts to avoid unnecessary lexsorts:
1. O(k) `_is_order_preserving_selection()` -- Checks if USS `._unit_ids` is
in the same relative order as in the parent. When True, the remapped vector
is guaranteed sorted (boolean filtering preserves order; the remap only
relabels unit_index values). This is the common case via `select_units()`
with a boolean mask.
2. O(n) `_is_spike_vector_sorted()` -- Checks if the remapped vector is still
sorted by (segment, sample, unit). Catches the case where the selection is
not order-preserving but no co-temporal (same exact sample) spikes exist.
Falls back to the original O(n log n) lexsort only when both checks fail.
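A minimal standalone sketch of what these two checks might look like (function bodies are illustrative, not the actual `_is_order_preserving_selection` / `_is_spike_vector_sorted` implementations; the structured dtype mirrors the spike vector's fields):

```python
import numpy as np

def is_order_preserving_selection(parent_unit_ids, selected_unit_ids):
    # O(k): True if the selected ids appear in the same relative order
    # as they do in the parent sorting.
    parent_pos = {uid: i for i, uid in enumerate(parent_unit_ids)}
    positions = [parent_pos[uid] for uid in selected_unit_ids]
    return all(a < b for a, b in zip(positions, positions[1:]))

def is_spike_vector_sorted(spike_vector):
    # O(n): True if the vector is sorted by (segment, sample, unit).
    d_seg = np.diff(spike_vector["segment_index"])
    d_samp = np.diff(spike_vector["sample_index"])
    d_unit = np.diff(spike_vector["unit_index"])
    ok = (d_seg > 0) | ((d_seg == 0) & ((d_samp > 0) | ((d_samp == 0) & (d_unit >= 0))))
    return bool(np.all(ok))
```

Only when both return False would the full `np.lexsort` over (unit, sample, segment) be needed.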
`BaseSorting` builds the spike vector with a per-unit boolean scan over spike_clusters, which is O(N*K). If we already have the full flat spike time and spike cluster arrays, we can do a lot better by building the spike vector in one shot. (I think O(N log N) from the lexsort, which is pessimistic, because the lexsort doesn't always need to happen. Under any circumstances I can dream of, K >> log N.) Since Phy/Kilosort segments already load the full flat arrays when the `PhyKilosortSorting` object is created, and keep them around as `._all_spikes` and `._all_clusters`, we can just use those! :) This also populates `_cached_spike_vector_segment_slices` directly, so that `BaseSorting`'s `_get_spike_vector_segment_slices()` lazy recomputation is skipped.
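The one-shot build can be sketched with plain NumPy (names and dtype are illustrative; the real spike vector also carries a `segment_index` field, and the real code would reuse the existing unit-id mapping rather than `np.unique`):

```python
import numpy as np

def build_spike_vector(all_spikes, all_clusters):
    # Map raw cluster labels to contiguous unit indices in one vectorized
    # pass (searchsorted against the sorted unique labels), instead of one
    # boolean scan per unit.
    unit_labels = np.unique(all_clusters)
    unit_index = np.searchsorted(unit_labels, all_clusters)
    spike_vector = np.zeros(
        len(all_spikes), dtype=[("sample_index", "int64"), ("unit_index", "int64")]
    )
    spike_vector["sample_index"] = all_spikes
    spike_vector["unit_index"] = unit_index
    # Lexsort only when the flat arrays are not already time-sorted
    # (Kilosort output usually is, so this branch is often skipped).
    if np.any(np.diff(spike_vector["sample_index"]) < 0):
        order = np.lexsort((spike_vector["unit_index"], spike_vector["sample_index"]))
        spike_vector = spike_vector[order]
    return spike_vector
```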
`BaseSortingSegment.get_unit_spike_trains()` loops over `get_unit_spike_train`, which is O(N*K) because each call is a boolean scan over `_all_clusters`/`_all_spikes`. If we know we are going to be getting all the trains, we can do it much faster. And if we can use numba, faster still. In fact, even if we only want _some_ spike trains, it is still often faster to get all the trains and just discard the ones we don't need than to fetch only the trains we need unit-by-unit (because we only ever store or cache flat arrays of spike times/clusters). Note that **only the use_cache=False path is affected**; the use_cache=True path triggers the computation of the spike vector, which I don't think can ever be the most efficient way to get spike trains.
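A NumPy-only sketch of the "get all trains in one pass" idea (illustrative names, no numba; a stable sort by cluster keeps within-unit time order, then each unit's train is a contiguous slice found by binary search):

```python
import numpy as np

def get_all_spike_trains(all_spikes, all_clusters, unit_ids):
    # One stable O(N log N) sort by cluster label, instead of one O(N)
    # boolean scan per unit (O(N*K) total).
    order = np.argsort(all_clusters, kind="stable")
    sorted_clusters = all_clusters[order]
    sorted_spikes = all_spikes[order]
    trains = {}
    for uid in unit_ids:
        # Each unit's spikes now form a contiguous block; find its bounds.
        left = np.searchsorted(sorted_clusters, uid, side="left")
        right = np.searchsorted(sorted_clusters, uid, side="right")
        trains[uid] = sorted_spikes[left:right]
    return trains
```

The per-unit work after the sort is O(log N), which is why discarding unwanted trains is usually cheaper than re-scanning per unit.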
…izations

- Fixed `test_compute_and_cache_spike_vector`: it was comparing an array to itself (`to_spike_vector` with `use_cache=False` still returns the cached vector). Now it explicitly calls the USS override and the `BaseSorting` implementation, and compares the two.
- Added `test_uss_get_unit_spike_trains_with_renamed_ids`: also not a test of the optimization commits per se, but it would have caught a mistake made along the way. Verifies `get_unit_spike_trains` returns child-keyed dicts (not parent-keyed).
- Added `test_spike_vector_sorted_after_reorder_with_cotemporal_spikes`: verifies the USS spike vector is correctly sorted when the selection reverses unit order and co-temporal spikes exist.
- Added `test_phy_sorting_segment_get_unit_spike_trains`: validates the new fast methods on `PhySortingSegment`.
- Added `test_phy_compute_and_cache_spike_vector`: verifies the Phy override of `_compute_and_cache_spike_vector` matches the `BaseSorting` implementation.
@alejoe91 my changes are PR'd to your fork whenever you're ready. The only thing I should point out that isn't in the commit messages:
Additions to get_unit_spike_trains PR
I am curious what prompted this? What profiling did you guys do? Any chance we could have a discussion here on the repo, at least to know what the performance benchmarks, reasoning, and validation were?
@grahamfindlay is doing very long chronic recordings. He does all the processing, and in a second iteration wants to load the phy sorting object, select some units, and get all the spike trains. Just caching the spike vector takes almost 4 minutes! Plus there were some additional lexsorts that can be avoided to speed up computation. At least that gives some context. @h-mayorquin @grahamfindlay maybe you can add some more details on benchmarks and profiling?
Here are example timings for various operations using one example subject. This subject only has ~400 million spikes; I have some with many more. FWIW, you shouldn't need long chronic recordings to see tangible improvements from most of these changes. I'd have to dig through my notes, but I tested with 100M spikes and the gains were still clear.
Comments:
There are rough edges with this PR that I know about:
Another question I haven't resolved:
```python
if use_cache:
    # TODO: speed things up
    ordered_spike_vector, slices = self.to_reordered_spike_vector(
        lexsort=("sample_index", "segment_index", "unit_index"),
        return_order=False,
        return_slices=True,
    )
    unit_indices = self.ids_to_indices(unit_ids)
    spike_trains = {}
    for unit_index, unit_id in zip(unit_indices, unit_ids):
        sl0, sl1 = slices[unit_index, segment_index, :]
        spikes = ordered_spike_vector[sl0:sl1]
        spike_frames = spikes["sample_index"]
        if start_frame is not None:
            start = np.searchsorted(spike_frames, start_frame)
            spike_frames = spike_frames[start:]
        if end_frame is not None:
            end = np.searchsorted(spike_frames, end_frame)
            spike_frames = spike_frames[:end]
        spike_trains[unit_id] = spike_frames
else:
    spike_trains = segment.get_unit_spike_trains(
        unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame
    )
```
I think this is overkill, and should be replaced with something like

```python
spike_trains = {
    unit_id: self.get_unit_spike_train(
        unit_id, start_frame=start_frame, end_frame=end_frame, use_cache=use_cache
    )
    for unit_id in unit_ids
}
return spike_trains
```

In my local testing, this gives the same speed results. The one gain is avoiding repeated time slicing, but I think it's a marginal gain for all this fairly confusing code, and the spike train code is already fairly complex.

Happy to be proven wrong with benchmarking from a very long recording, but I don't think it's worth doing this until we need it.
```python
def get_unit_spike_trains(
    self,
    unit_ids: np.ndarray | list,
```
I'd vote to get all spike trains if the user doesn't pass unit_ids. Surely almost all use cases for get_unit_spike_trains are to get all unit spike trains?
I'm all for that. I actually think that unless we are going to cache the spike trains as such (rather than as a reordered spike vector) -- and I don't think we are [1] -- we should just call the function get_all_spike_trains() and return them all. That would most accurately reflect what the function does. It would make it less surprising for the user that getting 30 spike trains takes the same time as getting 300 spike trains. It would probably also encourage better access patterns. And they can easily filter the dict themselves with a 1-liner like:

```python
spike_trains = {id: train for id, train in sorting.get_all_spike_trains().items() if id in unit_ids}
```

[1] I think caching both the spike trains and the spike vector would be bad, since the caches could drift out of sync with each other unless care were taken to avoid this, and syncing caches would presumably negate all the benefits of using one representation over another.
```python
start_frame=start_frame,
end_frame=end_frame,
```
I was expecting this code to use the times to figure out the start/end frames, and use them here. Instead, this code gets all spike trains then slices. Why?
(EDIT: I'm sure there is a good reason I've not thought of!!)
I also was confused by this at first, but I think it is because there's no guarantee that the sample returned by BaseRecording.time_to_sample_index() exactly corresponds to the time you give it (it is more like "last frame at or before"), so it can behave weirdly when you use it to get fetch bounds. For example, if a time vector has samples at [0.0, 0.1, 0.2] and you pass start_time=0.15 to get_unit_spike_trains_in_seconds(), time_to_sample_index(0.15) returns frame 1, but frame 1 has time 0.1 and should be excluded. @alejoe91 can confirm.
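The boundary behavior in that example can be reproduced with `np.searchsorted` (the time vector and `start_time` here are just the illustrative values from above, not actual SpikeInterface internals):

```python
import numpy as np

times = np.array([0.0, 0.1, 0.2])  # hypothetical per-frame time vector
start_time = 0.15

# "last frame at or before start_time" -- the behavior described above:
frame_before = np.searchsorted(times, start_time, side="right") - 1
print(frame_before)  # 1, but times[1] == 0.1 < start_time, so it must be excluded

# "first frame at or after start_time" -- a safe lower fetch bound:
frame_after = np.searchsorted(times, start_time, side="left")
print(frame_after)   # 2, and times[2] == 0.2 >= start_time
```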
You do raise a good point, which is that it seems inefficient to scan the whole train, depending on the underlying representation, and in fact I did implement the bounded scan in PhyKilosortSortingExtractor.get_unit_spike_trains(). Maybe what could be done is: get some conservative frame bounds, use those to fetch the underlying trains, and then do a final mask on the result. Something like:

```python
start_frame = None if start_time is None else first_frame_at_or_after(start_time)
end_frame = None if end_time is None else first_frame_at_or_after(end_time)
spike_frames = self.get_unit_spike_train(..., start_frame=start_frame, end_frame=end_frame)
spike_times = self.sample_index_to_time(spike_frames, ...)
spike_times = spike_times[spike_times >= start_time]
spike_times = spike_times[spike_times < end_time]
```

It seems plausible to me that this could save time.
- `use_cache` (to `get_unit_spike_train_in_seconds`)
- `to_reordered_spike_vector`
- `select_units` (@grahamfindlay)

TODO

- `get_unit_spike_trains` for `PhyKilosortSortingExtractor` (maybe in a follow-up)