Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -910,9 +910,27 @@ def __init__(self, sampling_frequency=None, t_start=None, time_vector=None):
self.sampling_frequency = sampling_frequency
self.t_start = t_start
self.time_vector = time_vector
self._num_channels = None

BaseSegment.__init__(self)

@property
def num_channels(self):
# Return an explicit value if a subclass set one (via the `num_channels` kwarg
# at construction or by assigning `self._num_channels = N`). Otherwise derive from
# the container recording through the weakref established in `add_segment`.
if self._num_channels is not None:
return self._num_channels
if self._parent_extractor is None:
return None
container_recording = self._parent_extractor()
if container_recording is None:
return None
return container_recording.get_num_channels()

def get_num_channels(self):
return self.num_channels

def get_times(self) -> np.ndarray:
if self.time_vector is not None:
self.time_vector = np.asarray(self.time_vector)
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/binaryrecordingextractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __del__(self):
class BinaryRecordingSegment(BaseRecordingSegment):
def __init__(self, file_path, sampling_frequency, t_start, num_channels, dtype, time_axis, file_offset):
BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency, t_start=t_start)
self.num_channels = num_channels
self._num_channels = num_channels
self.dtype = np.dtype(dtype)
self.file_offset = file_offset
self.time_axis = time_axis
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/core/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,7 @@ def __init__(
BaseRecordingSegment.__init__(self, sampling_frequency=sampling_frequency)

self.num_samples = num_samples
self.num_channels = num_channels
self._num_channels = num_channels
self.noise_block_size = noise_block_size
self.noise_levels = noise_levels
self.cov_matrix = cov_matrix
Expand Down Expand Up @@ -2075,9 +2075,9 @@ def get_traces(
channel_indices: list | None = None,
) -> np.ndarray:
if channel_indices is None:
n_channels = self.templates.shape[2]
n_channels = self.num_channels
elif isinstance(channel_indices, slice):
stop = channel_indices.stop if channel_indices.stop is not None else self.templates.shape[2]
stop = channel_indices.stop if channel_indices.stop is not None else self.num_channels
start = channel_indices.start if channel_indices.start is not None else 0
step = channel_indices.step if channel_indices.step is not None else 1
n_channels = math.ceil((stop - start) / step)
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/generation/drift_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,9 +542,9 @@ def get_traces(
end_frame = self.num_samples if end_frame is None else end_frame

if channel_indices is None:
n_channels = self.drifting_templates.num_channels
n_channels = self.num_channels
elif isinstance(channel_indices, slice):
stop = channel_indices.stop if channel_indices.stop is not None else self.drifting_templates.num_channels
stop = channel_indices.stop if channel_indices.stop is not None else self.num_channels
start = channel_indices.start if channel_indices.start is not None else 0
step = channel_indices.step if channel_indices.step is not None else 1
n_channels = math.ceil((stop - start) / step)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def __init__(
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.parent_recording_segment = parent_recording_segment
self.num_channels = num_channels
self._num_channels = num_channels
self.same_along_dim_chans = same_along_dim_chans
self.n_chans_each_pos = n_chans_each_pos
self._dtype = dtype
Expand Down
8 changes: 2 additions & 6 deletions src/spikeinterface/preprocessing/zero_channel_pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ def __init__(self, recording: BaseRecording, padding_start: int = 0, padding_end
for segment in recording.segments:
recording_segment = TracePaddedRecordingSegment(
segment,
recording.get_num_channels(),
self.dtype,
self.padding_start,
self.padding_end,
Expand All @@ -55,7 +54,6 @@ class TracePaddedRecordingSegment(BasePreprocessorSegment):
def __init__(
self,
recording_segment: BaseRecordingSegment,
num_channels,
dtype,
padding_left,
padding_end,
Expand All @@ -64,7 +62,6 @@ def __init__(
self.padding_start = padding_left
self.padding_end = padding_end
self.fill_value = fill_value
self.num_channels = num_channels
self.num_samples_in_original_segment = recording_segment.get_num_samples()
self.dtype = dtype

Expand Down Expand Up @@ -165,7 +162,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping:
self.parent_recording = recording
self.num_channels = num_channels
for segment in recording.segments:
recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.num_channels, self.channel_mapping)
recording_segment = ZeroChannelPaddedRecordingSegment(segment, self.channel_mapping)
self.add_recording_segment(recording_segment)

# only copy relevant metadata and properties
Expand All @@ -182,10 +179,9 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping:


class ZeroChannelPaddedRecordingSegment(BasePreprocessorSegment):
def __init__(self, recording_segment: BaseRecordingSegment, num_channels: int, channel_mapping: list):
def __init__(self, recording_segment: BaseRecordingSegment, channel_mapping: list):
BasePreprocessorSegment.__init__(self, recording_segment)
self.parent_recording_segment = recording_segment
self.num_channels = num_channels
self.channel_mapping = channel_mapping

def get_traces(self, start_frame, end_frame, channel_indices):
Expand Down
Loading