Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 50 additions & 0 deletions src/spikeinterface/core/basesorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,11 @@ def register_recording(self, recording, check_spike_frames: bool = True):
"Might be necessary for further postprocessing."
)
self._recording = recording
# The recording is now the source of truth for timestamps.
# Reset the sorting's own time offset so it doesn't conflict
# with the recording's t_start when accessed through get_start_time/get_end_time.
for segment in self.segments:
segment._t_start = 0

@property
def sorting_info(self):
Expand All @@ -347,6 +352,51 @@ def has_time_vector(self, segment_index: int | None = None) -> bool:
else:
return False

def get_start_time(self, segment_index: int | None = None) -> float:
"""Get the start time of the sorting segment.

If a recording is registered, returns the recording's start time.
Otherwise returns the sorting segment's own t_start (or 0.0).

Parameters
----------
segment_index : int or None, default: None
The segment index (required for multi-segment)

Returns
-------
float
The start time in seconds
"""
segment_index = self._check_segment_index(segment_index)
if self.has_recording():
return self._recording.get_start_time(segment_index=segment_index)
else:
segment = self.segments[segment_index]
return segment._t_start if segment._t_start is not None else 0.0

def get_end_time(self, segment_index: int | None = None) -> float | None:
"""Get the end time of the sorting segment.

If a recording is registered, returns the recording's end time.
Otherwise returns None (the sorting doesn't know the recording duration).

Parameters
----------
segment_index : int or None, default: None
The segment index (required for multi-segment)

Returns
-------
float or None
The end time in seconds, or None if no recording is registered.
"""
segment_index = self._check_segment_index(segment_index)
if self.has_recording():
return self._recording.get_end_time(segment_index=segment_index)
else:
return None

def get_times(self, segment_index=None):
"""
Get time vector for a registered recording segment.
Expand Down
53 changes: 53 additions & 0 deletions src/spikeinterface/core/tests/test_time_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,3 +445,56 @@ def test_shift_times_with_None_as_t_start():
assert recording.segments[0].t_start is None
recording.shift_times(shift=1.0) # Shift by one seconds should not generate an error
assert recording.get_start_time() == 1.0


class TestSortingTimeNoRecording:
"""Tests for time methods on BaseSorting without a registered recording."""

def test_get_start_time_default(self):
sorting = generate_sorting(num_units=5, durations=[10])
assert sorting.get_start_time(segment_index=0) == 0.0

def test_get_end_time_default(self):
sorting = generate_sorting(num_units=5, durations=[10])
assert sorting.get_end_time(segment_index=0) is None

def test_get_start_time_with_t_start(self):
sorting = generate_sorting(num_units=5, durations=[10])
sorting.segments[0]._t_start = 100.0
assert sorting.get_start_time(segment_index=0) == 100.0


class TestSortingTimeWithRecording:
"""
Tests for time methods on BaseSorting with a registered recording.
The key invariant: the recording is the source of truth for timestamps.
"""

def test_get_start_end_time(self):
recording = generate_recording(num_channels=4, durations=[10])
sorting = generate_sorting(num_units=5, durations=[10])
sorting.register_recording(recording)

assert sorting.get_start_time(segment_index=0) == recording.get_start_time(segment_index=0)
assert sorting.get_end_time(segment_index=0) == recording.get_end_time(segment_index=0)

def test_register_recording_resets_t_start(self):
"""Registering a recording resets _t_start so the recording is the sole source of truth."""
sorting = generate_sorting(num_units=5, durations=[10])
sorting.segments[0]._t_start = 100.0

recording = generate_recording(num_channels=4, durations=[10])
sorting.register_recording(recording)

assert sorting.segments[0]._t_start == 0
assert sorting.get_start_time(segment_index=0) == recording.get_start_time(segment_index=0)

def test_with_recording_shifted_start(self):
"""Recording with a non-zero t_start is reflected in the sorting."""
recording = generate_recording(num_channels=4, durations=[10])
recording.shift_times(shift=50.0)

sorting = generate_sorting(num_units=5, durations=[10])
sorting.register_recording(recording)

assert sorting.get_start_time(segment_index=0) == 50.0
Loading