From 5c48ca349f0201ec8ec87d68c358ec1420496648 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Thu, 16 Apr 2026 14:36:57 -0600 Subject: [PATCH] Add time getters for sorting --- src/spikeinterface/core/basesorting.py | 50 +++++++++++++++++ .../core/tests/test_time_handling.py | 53 +++++++++++++++++++ 2 files changed, 103 insertions(+) diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index cb68f3d455..64da35edc9 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -322,6 +322,11 @@ def register_recording(self, recording, check_spike_frames: bool = True): "Might be necessary for further postprocessing." ) self._recording = recording + # The recording is now the source of truth for timestamps. + # Reset the sorting's own time offset so it doesn't conflict + # with the recording's t_start when accessed through get_start_time/get_end_time. + for segment in self.segments: + segment._t_start = 0 @property def sorting_info(self): @@ -347,6 +352,51 @@ def has_time_vector(self, segment_index: int | None = None) -> bool: else: return False + def get_start_time(self, segment_index: int | None = None) -> float: + """Get the start time of the sorting segment. + + If a recording is registered, returns the recording's start time. + Otherwise returns the sorting segment's own t_start (or 0.0). + + Parameters + ---------- + segment_index : int or None, default: None + The segment index (required for multi-segment) + + Returns + ------- + float + The start time in seconds + """ + segment_index = self._check_segment_index(segment_index) + if self.has_recording(): + return self._recording.get_start_time(segment_index=segment_index) + else: + segment = self.segments[segment_index] + return segment._t_start if segment._t_start is not None else 0.0 + + def get_end_time(self, segment_index: int | None = None) -> float | None: + """Get the end time of the sorting segment. + + If a recording is registered, returns the recording's end time. + Otherwise returns None (the sorting doesn't know the recording duration). + + Parameters + ---------- + segment_index : int or None, default: None + The segment index (required for multi-segment) + + Returns + ------- + float or None + The end time in seconds, or None if no recording is registered. + """ + segment_index = self._check_segment_index(segment_index) + if self.has_recording(): + return self._recording.get_end_time(segment_index=segment_index) + else: + return None + def get_times(self, segment_index=None): """ Get time vector for a registered recording segment. diff --git a/src/spikeinterface/core/tests/test_time_handling.py b/src/spikeinterface/core/tests/test_time_handling.py index e03096ce14..50e9f75f7f 100644 --- a/src/spikeinterface/core/tests/test_time_handling.py +++ b/src/spikeinterface/core/tests/test_time_handling.py @@ -445,3 +445,56 @@ def test_shift_times_with_None_as_t_start(): assert recording.segments[0].t_start is None recording.shift_times(shift=1.0) # Shift by one seconds should not generate an error assert recording.get_start_time() == 1.0 + + +class TestSortingTimeNoRecording: + """Tests for time methods on BaseSorting without a registered recording.""" + + def test_get_start_time_default(self): + sorting = generate_sorting(num_units=5, durations=[10]) + assert sorting.get_start_time(segment_index=0) == 0.0 + + def test_get_end_time_default(self): + sorting = generate_sorting(num_units=5, durations=[10]) + assert sorting.get_end_time(segment_index=0) is None + + def test_get_start_time_with_t_start(self): + sorting = generate_sorting(num_units=5, durations=[10]) + sorting.segments[0]._t_start = 100.0 + assert sorting.get_start_time(segment_index=0) == 100.0 + + +class TestSortingTimeWithRecording: + """ + Tests for time methods on BaseSorting with a registered recording. + The key invariant: the recording is the source of truth for timestamps. + """ + + def test_get_start_end_time(self): + recording = generate_recording(num_channels=4, durations=[10]) + sorting = generate_sorting(num_units=5, durations=[10]) + sorting.register_recording(recording) + + assert sorting.get_start_time(segment_index=0) == recording.get_start_time(segment_index=0) + assert sorting.get_end_time(segment_index=0) == recording.get_end_time(segment_index=0) + + def test_register_recording_resets_t_start(self): + """Registering a recording resets _t_start so the recording is the sole source of truth.""" + sorting = generate_sorting(num_units=5, durations=[10]) + sorting.segments[0]._t_start = 100.0 + + recording = generate_recording(num_channels=4, durations=[10]) + sorting.register_recording(recording) + + assert sorting.segments[0]._t_start == 0 + assert sorting.get_start_time(segment_index=0) == recording.get_start_time(segment_index=0) + + def test_with_recording_shifted_start(self): + """Recording with a non-zero t_start is reflected in the sorting.""" + recording = generate_recording(num_channels=4, durations=[10]) + recording.shift_times(shift=50.0) + + sorting = generate_sorting(num_units=5, durations=[10]) + sorting.register_recording(recording) + + assert sorting.get_start_time(segment_index=0) == 50.0