From f88dc2dcafcde75601d30d124c8f661b5622314b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sat, 18 Apr 2026 14:07:00 -0600 Subject: [PATCH 01/20] First draft --- src/spikeinterface/core/baserecording.py | 4 +- .../core/baserecordingsnippets.py | 121 ++++++------------ src/spikeinterface/core/basesnippets.py | 2 +- .../core/channelsaggregationrecording.py | 15 ++- src/spikeinterface/core/channelslice.py | 64 +++++++-- src/spikeinterface/core/zarrextractors.py | 2 +- .../extractors/neoextractors/biocam.py | 6 +- .../extractors/tests/test_iblextractors.py | 1 - .../tests/test_interpolate_bad_channels.py | 20 ++- .../preprocessing/zero_channel_pad.py | 2 +- .../motion/motion_interpolation.py | 16 ++- 11 files changed, 137 insertions(+), 116 deletions(-) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index f23b524271..2c47693c32 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -637,7 +637,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) @@ -665,7 +665,7 @@ def _extra_metadata_from_folder(self, folder): def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() write_probeinterface(folder / "probe.json", probegroup) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 58e91ec35c..0b99dfe435 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -19,6 +19,7 @@ def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = float(sampling_frequency) self._dtype = np.dtype(dtype) + self._probegroup = None @property def channel_ids(self): @@ -51,7 +52,7 @@ def has_scaleable_traces(self) -> bool: return True def has_probe(self) -> bool: - return "contact_vector" in self.get_property_keys() + return self._probegroup is not None def has_channel_location(self) -> bool: return self.has_probe() or "location" in self.get_property_keys() @@ -145,24 +146,18 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): probe.device_channel_indices is not None for probe in probegroup.probes ), "Probe must have device_channel_indices" - # this is a vector with complex fileds (dataframe like) that handle all contact attr - probe_as_numpy_array = probegroup.to_numpy(complete=True) - - # keep only connected contact ( != -1) - keep = probe_as_numpy_array["device_channel_indices"] >= 0 - if np.any(~keep): + # identify connected contacts; device_channel_indices values are preserved as provenance + global_device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"] + connected_mask = global_device_channel_indices >= 0 + if np.any(~connected_mask): warn("The given probes have unconnected contacts: they are removed") - probe_as_numpy_array = probe_as_numpy_array[keep] - - device_channel_indices = probe_as_numpy_array["device_channel_indices"] - order = np.argsort(device_channel_indices) - device_channel_indices = device_channel_indices[order] + device_channel_indices = np.sort(global_device_channel_indices[connected_mask]) - # check TODO: Where did this came from? + # validate indices fit the recording number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) if number_of_device_channel_indices >= self.get_num_channels(): - error_msg = ( + raise ValueError( f"The given Probe either has 'device_channel_indices' that does not match channel count \n" f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" @@ -170,11 +165,13 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): f"device_channel_indices are the following: {device_channel_indices} \n" f"recording channels are the following: {self.get_channel_ids()} \n" ) - raise ValueError(error_msg) new_channel_ids = self.get_channel_ids()[device_channel_indices] - probe_as_numpy_array = probe_as_numpy_array[order] - probe_as_numpy_array["device_channel_indices"] = np.arange(probe_as_numpy_array.size, dtype="int64") + + # drop only the unconnected contacts from the stored probegroup; preserve device_channel_indices values + probegroup = probegroup.get_slice(connected_mask) + probegroup._build_contact_vector() + contact_vector = probegroup.contact_vector # create recording : channel slice or clone or self if in_place: @@ -187,25 +184,18 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): else: sub_recording = self.select_channels(new_channel_ids) - # create a vector that handle all contacts in property - sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None) + sub_recording._probegroup = probegroup - # planar_contour is saved in annotations - for probe_index, probe in enumerate(probegroup.probes): - contour = probe.probe_planar_contour - if contour is not None: - sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True) - - # duplicate positions to "locations" property + # duplicate positions to "location" property so SpikeInterface-level readers keep working ndim = probegroup.ndim - locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") + locations = np.zeros((contact_vector.size, ndim), dtype="float64") for i, dim in enumerate(["x", "y", "z"][:ndim]): - locations[:, i] = probe_as_numpy_array[dim] + locations[:, i] = contact_vector[dim] sub_recording.set_property("location", locations, ids=None) - # handle groups - has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields - has_contact_side = "contact_sides" in probe_as_numpy_array.dtype.fields + # derive groups from contact_vector + has_shank_id = "shank_ids" in contact_vector.dtype.fields + has_contact_side = "contact_sides" in contact_vector.dtype.fields if group_mode == "auto": group_keys = ["probe_index"] if has_shank_id: @@ -223,21 +213,15 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): group_keys = ["probe_index", "shank_ids", "contact_sides"] else: group_keys = ["probe_index", "contact_sides"] - groups = np.zeros(probe_as_numpy_array.size, dtype="int64") - unique_keys = np.unique(probe_as_numpy_array[group_keys]) + groups = np.zeros(contact_vector.size, dtype="int64") + unique_keys = np.unique(contact_vector[group_keys]) for group, a in enumerate(unique_keys): - mask = np.ones(probe_as_numpy_array.size, dtype=bool) + mask = np.ones(contact_vector.size, dtype=bool) for k in group_keys: - mask &= probe_as_numpy_array[k] == a[k] + mask &= contact_vector[k] == a[k] groups[mask] = group sub_recording.set_property("group", groups, ids=None) - # add probe annotations to recording - probes_info = [] - for probe in probegroup.probes: - probes_info.append(probe.annotations) - sub_recording.annotate(probes_info=probes_info) - return sub_recording def get_probe(self): @@ -250,30 +234,9 @@ def get_probes(self): return probegroup.probes def get_probegroup(self): - arr = self.get_property("contact_vector") - if arr is None: - positions = self.get_property("location") - if positions is None: - raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") - else: - warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") - probe = self.create_dummy_probe_from_locations(positions) - # probe.create_auto_shape() - probegroup = ProbeGroup() - probegroup.add_probe(probe) - else: - probegroup = ProbeGroup.from_numpy(arr) - - if "probes_info" in self.get_annotation_keys(): - probes_info = self.get_annotation("probes_info") - for probe, probe_info in zip(probegroup.probes, probes_info): - probe.annotations = probe_info - - for probe_index, probe in enumerate(probegroup.probes): - contour = self.get_annotation(f"probe_{probe_index}_planar_contour") - if contour is not None: - probe.set_planar_contour(contour) - return probegroup + if self._probegroup is None: + raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + return self._probegroup def _extra_metadata_from_folder(self, folder): # load probe @@ -284,7 +247,7 @@ def _extra_metadata_from_folder(self, folder): def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() write_probeinterface(folder / "probe.json", probegroup) @@ -341,7 +304,7 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params self.set_probe(probe, in_place=True) def set_channel_locations(self, locations, channel_ids=None): - if self.get_property("contact_vector") is not None: + if self.has_probe(): raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") self.set_property("location", locations, ids=channel_ids) @@ -349,21 +312,15 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - # here we bypass the probe reconstruction so this works both for probe and probegroup - ndim = len(axes) - all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") - for i, dim in enumerate(axes): - all_positions[:, i] = contact_vector[dim] - positions = all_positions[channel_indices] - return positions - else: - locations = self.get_property("location") - if locations is None: - raise Exception("There are no channel locations") - locations = np.asarray(locations)[channel_indices] - return select_axes(locations, axes) + if not self.has_probe(): + raise ValueError("get_channel_locations(..) needs a probe to be attached to the recording") + self._probegroup._build_contact_vector() + contact_vector = self._probegroup.contact_vector + ndim = len(axes) + all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") + for i, dim in enumerate(axes): + all_positions[:, i] = contact_vector[dim] + return all_positions[channel_indices] def has_3d_locations(self) -> bool: return self.get_property("location").shape[1] == 3 diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index b56a093ccc..fa47365200 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -259,7 +259,7 @@ def _save(self, format="npy", **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 697aab875e..59501d0ba1 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -2,6 +2,8 @@ import numpy as np +from probeinterface import ProbeGroup + from .baserecording import BaseRecording, BaseRecordingSegment @@ -90,11 +92,18 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record break for prop_name, prop_values in property_dict.items(): - if prop_name == "contact_vector": - # remap device channel indices correctly - prop_values["device_channel_indices"] = np.arange(self.get_num_channels()) self.set_property(key=prop_name, values=prop_values) + # aggregate probegroups across the inputs and reset wiring to the new channel order + if all(rec.has_probe() for rec in recording_list): + aggregated_probegroup = ProbeGroup() + for rec in recording_list: + for probe in rec.get_probegroup().probes: + aggregated_probegroup.add_probe(probe.copy()) + aggregated_probegroup.set_global_device_channel_indices(np.arange(self.get_num_channels(), dtype="int64")) + aggregated_probegroup._build_contact_vector() + self._probegroup = aggregated_probegroup + # if locations are present, check that they are all different! if "location" in self.get_property_keys(): location_tuple = [tuple(loc) for loc in self.get_property("location")] diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index de693d5c26..82e842dabc 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -61,11 +61,33 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_recording.copy_metadata(self, only_main=False, ids=self._channel_ids) self._parent = parent_recording - # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + # filter the probegroup to contacts wired to the retained channels + if parent_recording.has_probe(): + parent_probegroup = parent_recording.get_probegroup() + parent_dci_sorted = np.sort(parent_probegroup.get_global_device_channel_indices()["device_channel_indices"]) + child_dci_values = parent_dci_sorted[self._parent_channel_indices] + are_channels_reordered: bool = not np.all(np.diff(child_dci_values) >= 0) + probe_dci = parent_probegroup.get_global_device_channel_indices()["device_channel_indices"] + keep_mask = np.isin(probe_dci, child_dci_values) + sliced_probegroup = parent_probegroup.get_slice(keep_mask) + + if not are_channels_reordered: + # simple case: the child's channels are already in ascending device_channel_indices order + # so _build_contact_vector on the filtered probegroup will produce rows in the child's + # channel order. Nothing else to do. + pass + else: + # reorder case: the user picked channels in an order that does not match sort-by-dci. + # We have to rewrite device_channel_indices on the child's copy so that the sort done + # by _build_contact_vector aligns with the child's channel order. + new_dci_by_old = {int(d): new for new, d in enumerate(child_dci_values.tolist())} + sliced_dci = sliced_probegroup.get_global_device_channel_indices()["device_channel_indices"] + sliced_probegroup.set_global_device_channel_indices( + np.array([new_dci_by_old[int(d)] for d in sliced_dci], dtype="int64") + ) + + sliced_probegroup._build_contact_vector() + self._probegroup = sliced_probegroup # update dump dict self._kwargs = { @@ -151,11 +173,33 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): # copy annotation and properties parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) - # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + # filter the probegroup to contacts wired to the retained channels + if parent_snippets.has_probe(): + parent_probegroup = parent_snippets.get_probegroup() + parent_dci_sorted = np.sort(parent_probegroup.get_global_device_channel_indices()["device_channel_indices"]) + child_dci_values = parent_dci_sorted[self._parent_channel_indices] + are_channels_reordered: bool = not np.all(np.diff(child_dci_values) >= 0) + probe_dci = parent_probegroup.get_global_device_channel_indices()["device_channel_indices"] + keep_mask = np.isin(probe_dci, child_dci_values) + sliced_probegroup = parent_probegroup.get_slice(keep_mask) + + if not are_channels_reordered: + # simple case: the child's channels are already in ascending device_channel_indices order + # so _build_contact_vector on the filtered probegroup will produce rows in the child's + # channel order. Nothing else to do. + pass + else: + # reorder case: the user picked channels in an order that does not match sort-by-dci. + # We have to rewrite device_channel_indices on the child's copy so that the sort done + # by _build_contact_vector aligns with the child's channel order. + new_dci_by_old = {int(d): new for new, d in enumerate(child_dci_values.tolist())} + sliced_dci = sliced_probegroup.get_global_device_channel_indices()["device_channel_indices"] + sliced_probegroup.set_global_device_channel_indices( + np.array([new_dci_by_old[int(d)] for d in sliced_dci], dtype="int64") + ) + + sliced_probegroup._build_contact_vector() + self._probegroup = sliced_probegroup # update dump dict self._kwargs = { diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index 1ef5d76e5a..941a99b877 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -503,7 +503,7 @@ def add_recording_to_zarr_group( ) # save probe - if recording.get_property("contact_vector") is not None: + if recording.has_probe(): probegroup = recording.get_probegroup() zarr_group.attrs["probe"] = check_json(probegroup.to_dict(array_as_list=True)) diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 8d1fac0c72..c2ceb66523 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -1,5 +1,6 @@ from pathlib import Path +import numpy as np import probeinterface from spikeinterface.core.core_tools import define_function_from_class @@ -71,8 +72,9 @@ def __init__( probe_kwargs["electrode_width"] = electrode_width probe = probeinterface.read_3brain(file_path, **probe_kwargs) self.set_probe(probe, in_place=True) - self.set_property("row", self.get_property("contact_vector")["row"]) - self.set_property("col", self.get_property("contact_vector")["col"]) + probe = self.get_probegroup().probes[0] + self.set_property("row", np.asarray(probe.contact_annotations["row"])) + self.set_property("col", np.asarray(probe.contact_annotations["col"])) self._kwargs.update( { diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 5306de2441..ff21c7a3c7 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -84,7 +84,6 @@ def test_property_keys(self): expected_property_keys = [ "gain_to_uV", "offset_to_uV", - "contact_vector", "location", "group", "shank", diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index a571894374..6b40548bc4 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -126,8 +126,10 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan # distribute default probe locations across 4 shanks if set rng = np.random.default_rng(seed=None) x = rng.choice(shanks, num_channels) - for idx, __ in enumerate(recording._properties["contact_vector"]): - recording._properties["contact_vector"][idx][1] = x[idx] + probe = recording.get_probegroup().probes[0] + probe._contact_positions[:, 0] = x + recording._probegroup._build_contact_vector() + recording.set_property("location", recording.get_channel_locations()) # generate random bad channel locations bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False) @@ -170,9 +172,12 @@ def test_output_values(): [5, 5, 5, 7, 3], ] # all others equal distance away. # Overwrite the probe information with the new locations + probe = recording.get_probegroup().probes[0] for idx, (x, y) in enumerate(zip(*new_probe_locs)): - recording._properties["contact_vector"][idx][1] = x - recording._properties["contact_vector"][idx][2] = y + probe._contact_positions[idx, 0] = x + probe._contact_positions[idx, 1] = y + recording._probegroup._build_contact_vector() + recording.set_property("location", recording.get_channel_locations()) # Run interpolation in SI and check the interpolated channel # 0 is a linear combination of other channels @@ -186,8 +191,11 @@ def test_output_values(): # Shift the last channel position so that it is 4 units, rather than 2 # away. Setting sigma_um = p = 1 allows easy calculation of the expected # weights. - recording._properties["contact_vector"][-1][1] = 5 - recording._properties["contact_vector"][-1][2] = 9 + probe = recording.get_probegroup().probes[0] + probe._contact_positions[-1, 0] = 5 + probe._contact_positions[-1, 1] = 9 + recording._probegroup._build_contact_vector() + recording.set_property("location", recording.get_channel_locations()) expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 45d4809cd8..a84d3bbf64 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -157,7 +157,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: "The new mapping cannot exceed total number of channels " "in the zero-chanenl-padded recording." ) else: - if "locations" in recording.get_property_keys() or "contact_vector" in recording.get_property_keys(): + if recording.has_probe() or "location" in recording.get_property_keys(): self.channel_mapping = np.argsort(recording.get_channel_locations()[:, 1]) else: self.channel_mapping = np.arange(recording.get_num_channels()) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 7c4c4b166e..027f2bbb66 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -403,13 +403,15 @@ def __init__( dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) - if border_mode == "remove_channels": - # change the wiring of the probe - # TODO this is also done in ChannelSliceRecording, this should be done in a common place - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if border_mode == "remove_channels" and recording.has_probe(): + # filter the probegroup to contacts wired to the retained channels; order is preserved (channel_inds is ascending) + parent_probegroup = recording.get_probegroup() + parent_dci_sorted = np.sort(parent_probegroup.get_global_device_channel_indices()["device_channel_indices"]) + child_dci_values = parent_dci_sorted[channel_inds] + probe_dci = parent_probegroup.get_global_device_channel_indices()["device_channel_indices"] + sliced_probegroup = parent_probegroup.get_slice(np.isin(probe_dci, child_dci_values)) + sliced_probegroup._build_contact_vector() + self._probegroup = sliced_probegroup # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below From cd52ef333496f638e0fd7115c50625bec4ef403c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sat, 18 Apr 2026 14:10:47 -0600 Subject: [PATCH 02/20] second iteration --- .../core/baserecordingsnippets.py | 13 +++-- src/spikeinterface/core/channelslice.py | 52 +++---------------- .../motion/motion_interpolation.py | 8 ++- 3 files changed, 18 insertions(+), 55 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 0b99dfe435..b8de26df3d 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -146,13 +146,17 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): probe.device_channel_indices is not None for probe in probegroup.probes ), "Probe must have device_channel_indices" - # identify connected contacts; device_channel_indices values are preserved as provenance + # identify connected contacts and their channel-order global_device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"] connected_mask = global_device_channel_indices >= 0 if np.any(~connected_mask): warn("The given probes have unconnected contacts: they are removed") - device_channel_indices = np.sort(global_device_channel_indices[connected_mask]) + connected_contact_indices = np.where(connected_mask)[0] + connected_channel_values = global_device_channel_indices[connected_mask] + order = np.argsort(connected_channel_values) + sorted_contact_indices = connected_contact_indices[order] + device_channel_indices = connected_channel_values[order] # validate indices fit the recording number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) @@ -168,8 +172,9 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): new_channel_ids = self.get_channel_ids()[device_channel_indices] - # drop only the unconnected contacts from the stored probegroup; preserve device_channel_indices values - probegroup = probegroup.get_slice(connected_mask) + # slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange + probegroup = probegroup.get_slice(sorted_contact_indices) + probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64")) probegroup._build_contact_vector() contact_vector = probegroup.contact_vector diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 82e842dabc..6491113884 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -61,31 +61,11 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_recording.copy_metadata(self, only_main=False, ids=self._channel_ids) self._parent = parent_recording - # filter the probegroup to contacts wired to the retained channels + # slice the probegroup to the retained channels and reset wiring to the new channel order if parent_recording.has_probe(): parent_probegroup = parent_recording.get_probegroup() - parent_dci_sorted = np.sort(parent_probegroup.get_global_device_channel_indices()["device_channel_indices"]) - child_dci_values = parent_dci_sorted[self._parent_channel_indices] - are_channels_reordered: bool = not np.all(np.diff(child_dci_values) >= 0) - probe_dci = parent_probegroup.get_global_device_channel_indices()["device_channel_indices"] - keep_mask = np.isin(probe_dci, child_dci_values) - sliced_probegroup = parent_probegroup.get_slice(keep_mask) - - if not are_channels_reordered: - # simple case: the child's channels are already in ascending device_channel_indices order - # so _build_contact_vector on the filtered probegroup will produce rows in the child's - # channel order. Nothing else to do. - pass - else: - # reorder case: the user picked channels in an order that does not match sort-by-dci. - # We have to rewrite device_channel_indices on the child's copy so that the sort done - # by _build_contact_vector aligns with the child's channel order. - new_dci_by_old = {int(d): new for new, d in enumerate(child_dci_values.tolist())} - sliced_dci = sliced_probegroup.get_global_device_channel_indices()["device_channel_indices"] - sliced_probegroup.set_global_device_channel_indices( - np.array([new_dci_by_old[int(d)] for d in sliced_dci], dtype="int64") - ) - + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) sliced_probegroup._build_contact_vector() self._probegroup = sliced_probegroup @@ -173,31 +153,11 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): # copy annotation and properties parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) - # filter the probegroup to contacts wired to the retained channels + # slice the probegroup to the retained channels and reset wiring to the new channel order if parent_snippets.has_probe(): parent_probegroup = parent_snippets.get_probegroup() - parent_dci_sorted = np.sort(parent_probegroup.get_global_device_channel_indices()["device_channel_indices"]) - child_dci_values = parent_dci_sorted[self._parent_channel_indices] - are_channels_reordered: bool = not np.all(np.diff(child_dci_values) >= 0) - probe_dci = parent_probegroup.get_global_device_channel_indices()["device_channel_indices"] - keep_mask = np.isin(probe_dci, child_dci_values) - sliced_probegroup = parent_probegroup.get_slice(keep_mask) - - if not are_channels_reordered: - # simple case: the child's channels are already in ascending device_channel_indices order - # so _build_contact_vector on the filtered probegroup will produce rows in the child's - # channel order. Nothing else to do. - pass - else: - # reorder case: the user picked channels in an order that does not match sort-by-dci. - # We have to rewrite device_channel_indices on the child's copy so that the sort done - # by _build_contact_vector aligns with the child's channel order. - new_dci_by_old = {int(d): new for new, d in enumerate(child_dci_values.tolist())} - sliced_dci = sliced_probegroup.get_global_device_channel_indices()["device_channel_indices"] - sliced_probegroup.set_global_device_channel_indices( - np.array([new_dci_by_old[int(d)] for d in sliced_dci], dtype="int64") - ) - + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) sliced_probegroup._build_contact_vector() self._probegroup = sliced_probegroup diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 027f2bbb66..612b667f63 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -404,12 +404,10 @@ def __init__( BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) if border_mode == "remove_channels" and recording.has_probe(): - # filter the probegroup to contacts wired to the retained channels; order is preserved (channel_inds is ascending) + # slice the probegroup to the retained channels and reset wiring to the new channel order parent_probegroup = recording.get_probegroup() - parent_dci_sorted = np.sort(parent_probegroup.get_global_device_channel_indices()["device_channel_indices"]) - child_dci_values = parent_dci_sorted[channel_inds] - probe_dci = parent_probegroup.get_global_device_channel_indices()["device_channel_indices"] - sliced_probegroup = parent_probegroup.get_slice(np.isin(probe_dci, child_dci_values)) + sliced_probegroup = parent_probegroup.get_slice(channel_inds) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) sliced_probegroup._build_contact_vector() self._probegroup = sliced_probegroup From f52c39c81eeaaf1095727795ce47b6191ba4a600 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 11:00:55 -0600 Subject: [PATCH 03/20] using cache and consistent use of set_probe_groups --- src/spikeinterface/core/base.py | 12 ++++++++++++ src/spikeinterface/core/baserecordingsnippets.py | 5 ++--- .../core/channelsaggregationrecording.py | 11 ++++++----- src/spikeinterface/core/channelslice.py | 2 -- .../extractors/neoextractors/maxwell.py | 3 ++- .../sortingcomponents/motion/motion_interpolation.py | 1 - 6 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 8d149a7c49..3bac073de5 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -574,6 +574,9 @@ def to_dict( folder_metadata = Path(folder_metadata).resolve().absolute().relative_to(relative_to) dump_dict["folder_metadata"] = str(folder_metadata) + if getattr(self, "_probegroup", None) is not None: + dump_dict["probegroup"] = self._probegroup.to_dict(array_as_list=True) + return dump_dict @staticmethod @@ -1161,6 +1164,15 @@ def _load_extractor_from_dict(dic) -> "BaseExtractor": for k, v in dic["properties"].items(): extractor.set_property(k, v) + if "probegroup" in dic: + from probeinterface import ProbeGroup + + probegroup = ProbeGroup.from_dict(dic["probegroup"]) + if hasattr(extractor, "set_probegroup"): + extractor.set_probegroup(probegroup, in_place=True) + else: + extractor._probegroup = probegroup + return extractor diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index b8de26df3d..83ef0e1699 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -161,7 +161,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): # validate indices fit the recording number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) if number_of_device_channel_indices >= self.get_num_channels(): - raise ValueError( + error_msg = ( f"The given Probe either has 'device_channel_indices' that does not match channel count \n" f"{len(device_channel_indices)} vs {self.get_num_channels()} \n" f"or it's max index {number_of_device_channel_indices} is the same as the number of channels {self.get_num_channels()} \n" @@ -169,13 +169,13 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): f"device_channel_indices are the following: {device_channel_indices} \n" f"recording channels are the following: {self.get_channel_ids()} \n" ) + raise ValueError(error_msg) new_channel_ids = self.get_channel_ids()[device_channel_indices] # slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange probegroup = probegroup.get_slice(sorted_contact_indices) probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64")) - probegroup._build_contact_vector() contact_vector = probegroup.contact_vector # create recording : channel slice or clone or self @@ -319,7 +319,6 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra channel_indices = self.ids_to_indices(channel_ids) if not self.has_probe(): raise ValueError("get_channel_locations(..) needs a probe to be attached to the recording") - self._probegroup._build_contact_vector() contact_vector = self._probegroup.contact_vector ndim = len(axes) all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 59501d0ba1..2e5deb8703 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -2,7 +2,7 @@ import numpy as np -from probeinterface import ProbeGroup +from probeinterface import Probe, ProbeGroup from .baserecording import BaseRecording, BaseRecordingSegment @@ -94,15 +94,16 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record for prop_name, prop_values in property_dict.items(): self.set_property(key=prop_name, values=prop_values) - # aggregate probegroups across the inputs and reset wiring to the new channel order + # aggregate probegroups across the inputs and attach via the canonical path if all(rec.has_probe() for rec in recording_list): aggregated_probegroup = ProbeGroup() for rec in recording_list: for probe in rec.get_probegroup().probes: - aggregated_probegroup.add_probe(probe.copy()) + # round-trip through to_dict/from_dict because Probe.copy() drops contact_ids + # and annotations (tracked in probeinterface #421) + aggregated_probegroup.add_probe(Probe.from_dict(probe.to_dict(array_as_list=False))) aggregated_probegroup.set_global_device_channel_indices(np.arange(self.get_num_channels(), dtype="int64")) - aggregated_probegroup._build_contact_vector() - self._probegroup = aggregated_probegroup + self.set_probegroup(aggregated_probegroup, in_place=True) # if locations are present, check that they are all different! if "location" in self.get_property_keys(): diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 6491113884..5687001340 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -66,7 +66,6 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_probegroup = parent_recording.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - sliced_probegroup._build_contact_vector() self._probegroup = sliced_probegroup # update dump dict @@ -158,7 +157,6 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): parent_probegroup = parent_snippets.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - sliced_probegroup._build_contact_vector() self._probegroup = sliced_probegroup # update dump dict diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 932ecee106..875422d00b 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -75,7 +75,8 @@ def __init__( rec_name = self.neo_reader.rec_name probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) self.set_probe(probe, in_place=True) - self.set_property("electrode", self.get_property("contact_vector")["electrode"]) + probe = self.get_probegroup().probes[0] + self.set_property("electrode", np.asarray(probe.contact_annotations["electrode"])) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) @classmethod diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 612b667f63..0bb3f8b065 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -408,7 +408,6 @@ def __init__( parent_probegroup = recording.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(channel_inds) sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - sliced_probegroup._build_contact_vector() self._probegroup = sliced_probegroup # handle manual interpolation_time_bin_centers_s From 594122cda4d98a60c6341bfe40118242bfdce04d Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 11:01:30 -0600 Subject: [PATCH 04/20] using cache and consistent use of set_probe_groups --- src/spikeinterface/core/channelslice.py | 8 ++++---- .../sortingcomponents/motion/motion_interpolation.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index 5687001340..f7d498db04 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -61,12 +61,12 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_recording.copy_metadata(self, only_main=False, ids=self._channel_ids) self._parent = parent_recording - # slice the probegroup to the retained channels and reset wiring to the new channel order + # slice the probegroup to the retained channels and attach via the canonical path if parent_recording.has_probe(): parent_probegroup = parent_recording.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - self._probegroup = sliced_probegroup + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { @@ -152,12 +152,12 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): # copy annotation and properties parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) - # slice the probegroup to the retained channels and reset wiring to the new channel order + # slice the probegroup to the retained channels and attach via the canonical path if parent_snippets.has_probe(): parent_probegroup = parent_snippets.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - self._probegroup = sliced_probegroup + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index 0bb3f8b065..b88987f884 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -404,11 +404,11 @@ def __init__( BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) if border_mode == "remove_channels" and recording.has_probe(): - # slice the probegroup to the retained channels and reset wiring to the new channel order + # slice the probegroup to the retained channels and attach via the canonical path parent_probegroup = recording.get_probegroup() sliced_probegroup = parent_probegroup.get_slice(channel_inds) sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) - self._probegroup = sliced_probegroup + self.set_probegroup(sliced_probegroup, in_place=True) # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below From 8bdcdc9ab4118d4c000a41993969b0662773d35c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 13:27:00 -0600 Subject: [PATCH 05/20] recover cophy semantics --- src/spikeinterface/core/base.py | 36 +++++++++++++++++++ .../core/baserecordingsnippets.py | 4 +-- 2 files changed, 38 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index 3bac073de5..a05dbdb566 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -1172,10 +1172,46 @@ def _load_extractor_from_dict(dic) -> "BaseExtractor": extractor.set_probegroup(probegroup, in_place=True) else: extractor._probegroup = probegroup + elif "contact_vector" in dic.get("properties", {}): + _restore_probegroup_from_legacy_contact_vector(extractor) return extractor +def _restore_probegroup_from_legacy_contact_vector(extractor) -> None: + """ + Reconstruct a `ProbeGroup` from the legacy `contact_vector` property. + + Recordings saved before the probegroup refactor stored the probe as a structured numpy + array under the `contact_vector` property, with probe-level annotations under a separate + `probes_info` annotation and per-probe planar contours under `probe_{i}_planar_contour` + annotations. This function reconstructs a `ProbeGroup` from those legacy fields, attaches + it via the canonical `set_probegroup` path, and removes the legacy property so the new + and old representations do not coexist on the loaded extractor. + """ + from probeinterface import ProbeGroup + + contact_vector_array = extractor.get_property("contact_vector") + probegroup = ProbeGroup.from_numpy(contact_vector_array) + + if "probes_info" in extractor.get_annotation_keys(): + probes_info = extractor.get_annotation("probes_info") + for probe, probe_info in zip(probegroup.probes, probes_info): + probe.annotations = probe_info + + for probe_index, probe in enumerate(probegroup.probes): + contour = extractor._annotations.get(f"probe_{probe_index}_planar_contour") + if contour is not None: + probe.set_planar_contour(contour) + + if hasattr(extractor, "set_probegroup"): + extractor.set_probegroup(probegroup, in_place=True) + else: + extractor._probegroup = probegroup + + extractor._properties.pop("contact_vector", None) + + def _get_class_from_string(class_string): class_name = class_string.split(".")[-1] module = ".".join(class_string.split(".")[:-1]) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 83ef0e1699..ccc798dbdc 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -176,7 +176,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): # slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange probegroup = probegroup.get_slice(sorted_contact_indices) probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64")) - contact_vector = probegroup.contact_vector + contact_vector = probegroup._contact_vector # create recording : channel slice or clone or self if in_place: @@ -319,7 +319,7 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra channel_indices = self.ids_to_indices(channel_ids) if not self.has_probe(): raise ValueError("get_channel_locations(..) needs a probe to be attached to the recording") - contact_vector = self._probegroup.contact_vector + contact_vector = self._probegroup._contact_vector ndim = len(axes) all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") for i, dim in enumerate(axes): From d499bc0d23194430bb6d158e6e9a0c3e462ef78b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 16:08:40 -0600 Subject: [PATCH 06/20] add docstring --- .../core/baserecordingsnippets.py | 44 ++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index ccc798dbdc..8df1c68bba 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path import numpy as np @@ -230,18 +231,38 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): return sub_recording def get_probe(self): + """ + Return a copy of the single attached probe. + + Returns a deepcopy so callers can mutate the probe without affecting the + recording's internal state. To re-attach a mutated probe use `set_probe(...)`. + """ probes = self.get_probes() assert len(probes) == 1, "there are several probe use .get_probes() or get_probegroup()" return probes[0] def get_probes(self): + """ + Return a list of copies of the attached probes. + + Returns deepcopies so callers can mutate probes without affecting the + recording's internal state. To re-attach a mutated probe use + `set_probegroup(...)` or `set_probe(...)`. + """ probegroup = self.get_probegroup() return probegroup.probes def get_probegroup(self): + """ + Return a copy of the attached `ProbeGroup`. + + Returns a deepcopy so callers hold a snapshot independent of the recording's + internal state. Mutating the returned probegroup does not modify the + recording; to commit changes use `set_probegroup(...)`. + """ if self._probegroup is None: raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") - return self._probegroup + return copy.deepcopy(self._probegroup) def _extra_metadata_from_folder(self, folder): # load probe @@ -309,6 +330,14 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params self.set_probe(probe, in_place=True) def set_channel_locations(self, locations, channel_ids=None): + """ + Set channel locations directly on the `"location"` property. + + When a probe is attached, channel locations come from the probegroup and + `"location"` is a compatibility mirror maintained by `_set_probes`. Writing + directly to the property would diverge the mirror from the probegroup, so + this method raises in that case; reattach a modified probe via `set_probe`. + """ if self.has_probe(): raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") self.set_property("location", locations, ids=channel_ids) @@ -338,12 +367,25 @@ def clear_channel_locations(self, channel_ids=None): self.set_property("location", locations, ids=channel_ids) def set_channel_groups(self, groups, channel_ids=None): + """ + Set channel groups directly on the `"group"` property. + + When a probe is attached, the `"group"` property is a compatibility mirror + derived by `_set_probes` from the probegroup and the chosen `group_mode`. + Writing groups directly bypasses that derivation and can diverge from the + probegroup; prefer re-attaching via `set_probe(..., group_mode=...)`. + """ if "probes" in self._annotations: warn("set_channel_groups() destroys the probe description. Using set_probe() is preferable") self._annotations.pop("probes") self.set_property("group", groups, ids=channel_ids) def get_channel_groups(self, channel_ids=None): + # Note: `"group"` is a compatibility mirror of the probegroup-derived grouping + # when a probe is attached, populated at `_set_probes` time. It is read directly + # here because the `group_mode` used to derive it is not currently persisted on + # the recording. Follow-up work may unify this with `get_channel_locations` by + # reading directly from the attached probegroup. groups = self.get_property("group", ids=channel_ids) return groups From f15e5c6514f53e09d7e6a6bd27e9957fb6e99b0b Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 16:20:57 -0600 Subject: [PATCH 07/20] just copies --- .../core/baserecordingsnippets.py | 24 ++++--------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 8df1c68bba..268585ff3a 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -231,37 +231,21 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): return sub_recording def get_probe(self): - """ - Return a copy of the single attached probe. - - Returns a deepcopy so callers can mutate the probe without affecting the - recording's internal state. To re-attach a mutated probe use `set_probe(...)`. - """ probes = self.get_probes() assert len(probes) == 1, "there are several probe use .get_probes() or get_probegroup()" return probes[0] def get_probes(self): - """ - Return a list of copies of the attached probes. - - Returns deepcopies so callers can mutate probes without affecting the - recording's internal state. To re-attach a mutated probe use - `set_probegroup(...)` or `set_probe(...)`. - """ probegroup = self.get_probegroup() return probegroup.probes def get_probegroup(self): - """ - Return a copy of the attached `ProbeGroup`. - - Returns a deepcopy so callers hold a snapshot independent of the recording's - internal state. Mutating the returned probegroup does not modify the - recording; to commit changes use `set_probegroup(...)`. - """ if self._probegroup is None: raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + # Return a deepcopy for backwards compatibility: pre-migration `main` reconstructed + # a fresh `ProbeGroup` from the stored structured array on each call, so external + # callers relied on value semantics. Handing out the live `_probegroup` would be a + # silent behavioural change. return copy.deepcopy(self._probegroup) def _extra_metadata_from_folder(self, folder): From e1b71032b5e1dece03ea1566d8fddf453f480bed Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 16:22:49 -0600 Subject: [PATCH 08/20] remove comments --- .../core/baserecordingsnippets.py | 21 ------------------- 1 file changed, 21 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 268585ff3a..b12a753ae9 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -314,14 +314,6 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params self.set_probe(probe, in_place=True) def set_channel_locations(self, locations, channel_ids=None): - """ - Set channel locations directly on the `"location"` property. - - When a probe is attached, channel locations come from the probegroup and - `"location"` is a compatibility mirror maintained by `_set_probes`. Writing - directly to the property would diverge the mirror from the probegroup, so - this method raises in that case; reattach a modified probe via `set_probe`. - """ if self.has_probe(): raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") self.set_property("location", locations, ids=channel_ids) @@ -351,25 +343,12 @@ def clear_channel_locations(self, channel_ids=None): self.set_property("location", locations, ids=channel_ids) def set_channel_groups(self, groups, channel_ids=None): - """ - Set channel groups directly on the `"group"` property. - - When a probe is attached, the `"group"` property is a compatibility mirror - derived by `_set_probes` from the probegroup and the chosen `group_mode`. - Writing groups directly bypasses that derivation and can diverge from the - probegroup; prefer re-attaching via `set_probe(..., group_mode=...)`. - """ if "probes" in self._annotations: warn("set_channel_groups() destroys the probe description. Using set_probe() is preferable") self._annotations.pop("probes") self.set_property("group", groups, ids=channel_ids) def get_channel_groups(self, channel_ids=None): - # Note: `"group"` is a compatibility mirror of the probegroup-derived grouping - # when a probe is attached, populated at `_set_probes` time. It is read directly - # here because the `group_mode` used to derive it is not currently persisted on - # the recording. Follow-up work may unify this with `get_channel_locations` by - # reading directly from the attached probegroup. groups = self.get_property("group", ids=channel_ids) return groups From 5d1f57d1b4fdbc4ce44dc50f783353b22b5990ae Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 16:32:02 -0600 Subject: [PATCH 09/20] rename --- .../core/baserecordingsnippets.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index b12a753ae9..d2d3c21716 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -177,7 +177,7 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): # slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange probegroup = probegroup.get_slice(sorted_contact_indices) probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64")) - contact_vector = probegroup._contact_vector + probe_as_numpy_array = probegroup._contact_vector # create recording : channel slice or clone or self if in_place: @@ -194,14 +194,14 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): # duplicate positions to "location" property so SpikeInterface-level readers keep working ndim = probegroup.ndim - locations = np.zeros((contact_vector.size, ndim), dtype="float64") + locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") for i, dim in enumerate(["x", "y", "z"][:ndim]): - locations[:, i] = contact_vector[dim] + locations[:, i] = probe_as_numpy_array[dim] sub_recording.set_property("location", locations, ids=None) # derive groups from contact_vector - has_shank_id = "shank_ids" in contact_vector.dtype.fields - has_contact_side = "contact_sides" in contact_vector.dtype.fields + has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields + has_contact_side = "contact_sides" in probe_as_numpy_array.dtype.fields if group_mode == "auto": group_keys = ["probe_index"] if has_shank_id: @@ -219,12 +219,12 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): group_keys = ["probe_index", "shank_ids", "contact_sides"] else: group_keys = ["probe_index", "contact_sides"] - groups = np.zeros(contact_vector.size, dtype="int64") - unique_keys = np.unique(contact_vector[group_keys]) + groups = np.zeros(probe_as_numpy_array.size, dtype="int64") + unique_keys = np.unique(probe_as_numpy_array[group_keys]) for group, a in enumerate(unique_keys): - mask = np.ones(contact_vector.size, dtype=bool) + mask = np.ones(probe_as_numpy_array.size, dtype=bool) for k in group_keys: - mask &= contact_vector[k] == a[k] + mask &= probe_as_numpy_array[k] == a[k] groups[mask] = group sub_recording.set_property("group", groups, ids=None) From c02e7602bfff3e43edae4d23681229d6a85e0970 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 16:38:12 -0600 Subject: [PATCH 10/20] more backwards compatability --- .../core/baserecordingsnippets.py | 31 ++++++++++++++----- 1 file changed, 23 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index d2d3c21716..92c5bfb858 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -228,6 +228,17 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): groups[mask] = group sub_recording.set_property("group", groups, ids=None) + # TODO discuss backwards compatibility: mirror probe-level annotations and planar + # contours as recording-level annotations so external code that reads these keys + # keeps working. The canonical source is now `probe.annotations` and + # `probe.probe_planar_contour` on the attached probegroup. + probes_info = [probe.annotations for probe in probegroup.probes] + sub_recording.annotate(probes_info=probes_info) + for probe_index, probe in enumerate(probegroup.probes): + contour = probe.probe_planar_contour + if contour is not None: + sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True) + return sub_recording def get_probe(self): @@ -322,14 +333,18 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) - if not self.has_probe(): - raise ValueError("get_channel_locations(..) needs a probe to be attached to the recording") - contact_vector = self._probegroup._contact_vector - ndim = len(axes) - all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") - for i, dim in enumerate(axes): - all_positions[:, i] = contact_vector[dim] - return all_positions[channel_indices] + if self.has_probe(): + contact_vector = self._probegroup._contact_vector + ndim = len(axes) + all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") + for i, dim in enumerate(axes): + all_positions[:, i] = contact_vector[dim] + return all_positions[channel_indices] + locations = self.get_property("location") + if locations is None: + raise Exception("There are no channel locations") + locations = np.asarray(locations)[channel_indices] + return select_axes(locations, axes) def has_3d_locations(self) -> bool: return self.get_property("location").shape[1] == 3 From 591be6353c203b02d077856a1065cddb34f2c014 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 16:48:33 -0600 Subject: [PATCH 11/20] testing --- pyproject.toml | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e3eb85b6e5..48ca4333eb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -127,7 +127,8 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge + "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # for slurm jobs, @@ -139,7 +140,8 @@ test_extractors = [ "pooch>=1.8.2", "datalad>=1.0.2", # Commenting out for release - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge + "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] @@ -190,7 +192,8 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge + "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # for slurm jobs @@ -219,7 +222,8 @@ docs = [ "huggingface_hub", # For automated curation # for release we need pypi, so this needs to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge + "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] From ed22a738e09304362f01a8b8c150f587a7660691 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 18:10:45 -0600 Subject: [PATCH 12/20] beahvior for 0 channel recording --- .../core/baserecordingsnippets.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 92c5bfb858..502a59ae61 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -174,10 +174,13 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): new_channel_ids = self.get_channel_ids()[device_channel_indices] + # capture ndim before slicing; get_slice with an empty selection yields a probegroup + # with no probes, on which `.ndim` raises + ndim = probegroup.ndim + # slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange probegroup = probegroup.get_slice(sorted_contact_indices) probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64")) - probe_as_numpy_array = probegroup._contact_vector # create recording : channel slice or clone or self if in_place: @@ -192,8 +195,20 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): sub_recording._probegroup = probegroup + # TODO: revisit whether set_probe with a fully unconnected probe should raise + # instead of returning a zero-channel recording. Preserved here for backwards + # compatibility with a test in test_BaseRecording; that test case should be + # peeled into its own named test so this assumption is easy to find and + # discuss when we decide to tighten the behaviour. + if len(device_channel_indices) == 0: + sub_recording.set_property("location", np.zeros((0, ndim), dtype="float64"), ids=None) + sub_recording.set_property("group", np.zeros(0, dtype="int64"), ids=None) + sub_recording.annotate(probes_info=[]) + return sub_recording + + probe_as_numpy_array = probegroup._contact_vector + # duplicate positions to "location" property so SpikeInterface-level readers keep working - ndim = probegroup.ndim locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") for i, dim in enumerate(["x", "y", "z"][:ndim]): locations[:, i] = probe_as_numpy_array[dim] From 565d7759eb47e1f0b2f45926ee7f992366fdc488 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 18:21:39 -0600 Subject: [PATCH 13/20] more fixes --- .../core/channelsaggregationrecording.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 2e5deb8703..64aa7e5bc2 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -97,12 +97,20 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record # aggregate probegroups across the inputs and attach via the canonical path if all(rec.has_probe() for rec in recording_list): aggregated_probegroup = ProbeGroup() + offset = 0 for rec in recording_list: for probe in rec.get_probegroup().probes: # round-trip through to_dict/from_dict because Probe.copy() drops contact_ids # and annotations (tracked in probeinterface #421) - aggregated_probegroup.add_probe(Probe.from_dict(probe.to_dict(array_as_list=False))) - aggregated_probegroup.set_global_device_channel_indices(np.arange(self.get_num_channels(), dtype="int64")) + probe_copy = Probe.from_dict(probe.to_dict(array_as_list=False)) + # assign non-colliding device_channel_indices before add_probe so the + # cross-probe uniqueness check does not fire on children that share + # child-local wiring (each sub-recording's probe was reset to arange + # when it was created via set_probe) + n = probe_copy.get_contact_count() + probe_copy.set_device_channel_indices(np.arange(offset, offset + n, dtype="int64")) + aggregated_probegroup.add_probe(probe_copy) + offset += n self.set_probegroup(aggregated_probegroup, in_place=True) # if locations are present, check that they are all different! From 5ef02216f0d7c1c666b434998c400e13af866919 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 18:44:42 -0600 Subject: [PATCH 14/20] second fix --- .../core/channelsaggregationrecording.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 64aa7e5bc2..90be729664 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -94,24 +94,23 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record for prop_name, prop_values in property_dict.items(): self.set_property(key=prop_name, values=prop_values) - # aggregate probegroups across the inputs and attach via the canonical path + # split_by resets each child probe's device_channel_indices, so the information + # of which contact was connected to which channel of the parent is lost by the + # time we aggregate. We rebuild a globally-unique wiring via per-probe offsets + # and skip set_probegroup because children also share contact positions. if all(rec.has_probe() for rec in recording_list): aggregated_probegroup = ProbeGroup() offset = 0 for rec in recording_list: for probe in rec.get_probegroup().probes: - # round-trip through to_dict/from_dict because Probe.copy() drops contact_ids - # and annotations (tracked in probeinterface #421) + # round-trip through to_dict/from_dict because Probe.copy() drops + # contact_ids and annotations (probeinterface #421) probe_copy = Probe.from_dict(probe.to_dict(array_as_list=False)) - # assign non-colliding device_channel_indices before add_probe so the - # cross-probe uniqueness check does not fire on children that share - # child-local wiring (each sub-recording's probe was reset to arange - # when it was created via set_probe) n = probe_copy.get_contact_count() probe_copy.set_device_channel_indices(np.arange(offset, offset + n, dtype="int64")) aggregated_probegroup.add_probe(probe_copy) offset += n - self.set_probegroup(aggregated_probegroup, in_place=True) + self._probegroup = aggregated_probegroup # if locations are present, check that they are all different! if "location" in self.get_property_keys(): From 07f4e9e79d1f5c889d1c45a9403fe0a3fe0e4a9c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 19:07:49 -0600 Subject: [PATCH 15/20] remove non-overallaping redundant check --- src/spikeinterface/core/sortinganalyzer.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8e16757bcc..ad444fca4b 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -23,7 +23,7 @@ from spikeinterface.core import BaseRecording, BaseSorting, aggregate_channels, aggregate_units from spikeinterface.core.waveform_tools import has_exceeding_spikes -from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match +from .recording_tools import get_rec_attributes, do_recording_attributes_match from .core_tools import ( check_json, retrieve_importing_provenance, @@ -363,10 +363,6 @@ def create( f"recording: {recording.sampling_frequency} - sorting: {sorting.sampling_frequency}. " "Ensure that you are associating the correct Recording and Sorting when creating a SortingAnalyzer." ) - # check that multiple probes are non-overlapping - all_probes = recording.get_probegroup().probes - check_probe_do_not_overlap(all_probes) - if has_exceeding_spikes(sorting=sorting, recording=recording): warnings.warn( "Your sorting has spikes with samples times greater than your recording length. These spikes have been removed." From e6129bc33a50c0892a3c1827de31b7135be0281c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 19:13:52 -0600 Subject: [PATCH 16/20] another fallack --- src/spikeinterface/core/baserecordingsnippets.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 502a59ae61..63413c8768 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -267,7 +267,18 @@ def get_probes(self): def get_probegroup(self): if self._probegroup is None: - raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + # Backwards-compat fallback: pre-migration get_probegroup synthesised a dummy + # probe from the "location" property when no probe had been attached. Callers + # (e.g. sparsity.py) rely on this for recordings that have locations but no + # probe. + positions = self.get_property("location") + if positions is None: + raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") + warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") + probe = self.create_dummy_probe_from_locations(positions) + pg = ProbeGroup() + pg.add_probe(probe) + return copy.deepcopy(pg) # Return a deepcopy for backwards compatibility: pre-migration `main` reconstructed # a fresh `ProbeGroup` from the stored structured array on each call, so external # callers relied on value semantics. Handing out the live `_probegroup` would be a From 6eb09a65ab671533e073dd19e64e8abb8e2e91a4 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 20:01:22 -0600 Subject: [PATCH 17/20] propgate to children --- src/spikeinterface/preprocessing/basepreprocessor.py | 7 +++++++ .../preprocessing/tests/test_interpolate_bad_channels.py | 6 +++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/preprocessing/basepreprocessor.py b/src/spikeinterface/preprocessing/basepreprocessor.py index 64d57d3637..4e18516a80 100644 --- a/src/spikeinterface/preprocessing/basepreprocessor.py +++ b/src/spikeinterface/preprocessing/basepreprocessor.py @@ -21,6 +21,13 @@ def __init__(self, recording, sampling_frequency=None, channel_ids=None, dtype=N recording.copy_metadata(self, only_main=False, ids=channel_ids) self._parent = recording + # Propagate the attached probegroup. copy_metadata only handles annotations + # and properties; `_probegroup` is a direct attribute and needs its own path. + # Subclasses that change channels (e.g. slicing) should override by slicing + # the probegroup themselves via set_probegroup. + if getattr(recording, "_probegroup", None) is not None and channel_ids is None: + self._probegroup = recording._probegroup + # self._kwargs have to be handled in subclass diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index 6b40548bc4..fd835df5c7 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -126,7 +126,7 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan # distribute default probe locations across 4 shanks if set rng = np.random.default_rng(seed=None) x = rng.choice(shanks, num_channels) - probe = recording.get_probegroup().probes[0] + probe = recording._probegroup.probes[0] probe._contact_positions[:, 0] = x recording._probegroup._build_contact_vector() recording.set_property("location", recording.get_channel_locations()) @@ -172,7 +172,7 @@ def test_output_values(): [5, 5, 5, 7, 3], ] # all others equal distance away. # Overwrite the probe information with the new locations - probe = recording.get_probegroup().probes[0] + probe = recording._probegroup.probes[0] for idx, (x, y) in enumerate(zip(*new_probe_locs)): probe._contact_positions[idx, 0] = x probe._contact_positions[idx, 1] = y @@ -191,7 +191,7 @@ def test_output_values(): # Shift the last channel position so that it is 4 units, rather than 2 # away. Setting sigma_um = p = 1 allows easy calculation of the expected # weights. - probe = recording.get_probegroup().probes[0] + probe = recording._probegroup.probes[0] probe._contact_positions[-1, 0] = 5 probe._contact_positions[-1, 1] = 9 recording._probegroup._build_contact_vector() From 1f4afee9631addc2007aa45f7932d922b009cf4f Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 20:19:21 -0600 Subject: [PATCH 18/20] fix tests --- src/spikeinterface/extractors/tests/test_iblextractors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index ff21c7a3c7..dfbf5d714d 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -68,11 +68,11 @@ def test_channel_ids(self): def test_gains(self): expected_gains = 2.34375 * np.ones(shape=384) - assert_array_equal(x=self.recording.get_channel_gains(), y=expected_gains) + assert_array_equal(self.recording.get_channel_gains(), expected_gains) def test_offsets(self): expected_offsets = np.zeros(shape=384) - assert_array_equal(x=self.recording.get_channel_offsets(), y=expected_offsets) + assert_array_equal(self.recording.get_channel_offsets(), expected_offsets) def test_probe_representation(self): probe = self.recording.get_probe() @@ -141,11 +141,11 @@ def test_channel_ids(self): def test_gains(self): expected_gains = np.concatenate([2.34375 * np.ones(shape=384), [1171.875]]) - assert_array_equal(x=self.recording.get_channel_gains(), y=expected_gains) + assert_array_equal(self.recording.get_channel_gains(), expected_gains) def test_offsets(self): expected_offsets = np.zeros(shape=385) - assert_array_equal(x=self.recording.get_channel_offsets(), y=expected_offsets) + assert_array_equal(self.recording.get_channel_offsets(), expected_offsets) def test_probe_representation(self): expected_exception = ValueError From eb07eeaea9b614d866a735e05ab3408f6fc08dae Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 23:01:57 -0600 Subject: [PATCH 19/20] fixes --- src/spikeinterface/preprocessing/tests/test_filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index 398b6cbc0e..e95b456542 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -50,7 +50,7 @@ def test_causal_filter_main_kwargs(self, recording_and_data): # Then, change all kwargs to ensure they are propagated # and check the backwards version. - options["band"] = [671] + options["band"] = 671 options["btype"] = "highpass" options["filter_order"] = 8 options["ftype"] = "bessel" From cc9e9b0e83c3adac4d98ee79b65fe0f09292926c Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Sun, 19 Apr 2026 23:44:05 -0600 Subject: [PATCH 20/20] another numpy fix --- .../postprocessing/tests/test_principal_component.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 77bff7a3d8..0e65bb2338 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -101,7 +101,7 @@ def test_get_projections(self, sparse): random_spikes_ext = sorting_analyzer.get_extension("random_spikes") random_spikes_indices = random_spikes_ext.get_data() - unit_ids_num_random_spikes = np.sum(random_spikes_ext.params["max_spikes_per_unit"] for _ in some_unit_ids) + unit_ids_num_random_spikes = sum(random_spikes_ext.params["max_spikes_per_unit"] for _ in some_unit_ids) # this should be all spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None)