diff --git a/pyproject.toml b/pyproject.toml index 5f9ae34281..443cf0f360 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,7 +128,8 @@ test_core = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge + "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # for slurm jobs, @@ -140,7 +141,8 @@ test_extractors = [ "pooch>=1.8.2", "datalad>=1.0.2", # Commenting out for release - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge + "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", ] @@ -191,7 +193,8 @@ test = [ # for github test : probeinterface and neo from master # for release we need pypi, so this need to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", + # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge + "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # for slurm jobs @@ -220,7 +223,8 @@ docs = [ "huggingface_hub", # For automated curation # for release we need pypi, so this needs to be commented - "probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version + # TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge + "probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle", "neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version ] diff --git a/src/spikeinterface/core/base.py b/src/spikeinterface/core/base.py index fcbafdb6bf..6e61ae894f 100644 --- a/src/spikeinterface/core/base.py +++ b/src/spikeinterface/core/base.py @@ -574,6 +574,9 @@ def to_dict( folder_metadata = Path(folder_metadata).resolve().absolute().relative_to(relative_to) dump_dict["folder_metadata"] = str(folder_metadata) + if getattr(self, "_probegroup", None) is not None: + dump_dict["probegroup"] = self._probegroup.to_dict(array_as_list=True) + return dump_dict @staticmethod @@ -1155,9 +1158,54 @@ def _load_extractor_from_dict(dic) -> "BaseExtractor": for k, v in dic["properties"].items(): extractor.set_property(k, v) + if "probegroup" in dic: + from probeinterface import ProbeGroup + + probegroup = ProbeGroup.from_dict(dic["probegroup"]) + if hasattr(extractor, "set_probegroup"): + extractor.set_probegroup(probegroup, in_place=True) + else: + extractor._probegroup = probegroup + elif "contact_vector" in dic.get("properties", {}): + _restore_probegroup_from_legacy_contact_vector(extractor) + return extractor +def _restore_probegroup_from_legacy_contact_vector(extractor) -> None: + """ + Reconstruct a `ProbeGroup` from the legacy `contact_vector` property. + + Recordings saved before the probegroup refactor stored the probe as a structured numpy + array under the `contact_vector` property, with probe-level annotations under a separate + `probes_info` annotation and per-probe planar contours under `probe_{i}_planar_contour` + annotations. This function reconstructs a `ProbeGroup` from those legacy fields, attaches + it via the canonical `set_probegroup` path, and removes the legacy property so the new + and old representations do not coexist on the loaded extractor. + """ + from probeinterface import ProbeGroup + + contact_vector_array = extractor.get_property("contact_vector") + probegroup = ProbeGroup.from_numpy(contact_vector_array) + + if "probes_info" in extractor.get_annotation_keys(): + probes_info = extractor.get_annotation("probes_info") + for probe, probe_info in zip(probegroup.probes, probes_info): + probe.annotations = probe_info + + for probe_index, probe in enumerate(probegroup.probes): + contour = extractor._annotations.get(f"probe_{probe_index}_planar_contour") + if contour is not None: + probe.set_planar_contour(contour) + + if hasattr(extractor, "set_probegroup"): + extractor.set_probegroup(probegroup, in_place=True) + else: + extractor._probegroup = probegroup + + extractor._properties.pop("contact_vector", None) + + def _get_class_from_string(class_string): class_name = class_string.split(".")[-1] module = ".".join(class_string.split(".")[:-1]) diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 068d2a047c..2993d5e135 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -387,7 +387,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) @@ -409,7 +409,7 @@ def _extra_metadata_from_folder(self, folder): def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() write_probeinterface(folder / "probe.json", probegroup) diff --git a/src/spikeinterface/core/baserecordingsnippets.py b/src/spikeinterface/core/baserecordingsnippets.py index 58e91ec35c..63413c8768 100644 --- a/src/spikeinterface/core/baserecordingsnippets.py +++ b/src/spikeinterface/core/baserecordingsnippets.py @@ -1,3 +1,4 @@ +import copy from pathlib import Path import numpy as np @@ -19,6 +20,7 @@ def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype BaseExtractor.__init__(self, channel_ids) self._sampling_frequency = float(sampling_frequency) self._dtype = np.dtype(dtype) + self._probegroup = None @property def channel_ids(self): @@ -51,7 +53,7 @@ def has_scaleable_traces(self) -> bool: return True def has_probe(self) -> bool: - return "contact_vector" in self.get_property_keys() + return self._probegroup is not None def has_channel_location(self) -> bool: return self.has_probe() or "location" in self.get_property_keys() @@ -145,21 +147,19 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): probe.device_channel_indices is not None for probe in probegroup.probes ), "Probe must have device_channel_indices" - # this is a vector with complex fileds (dataframe like) that handle all contact attr - probe_as_numpy_array = probegroup.to_numpy(complete=True) - - # keep only connected contact ( != -1) - keep = probe_as_numpy_array["device_channel_indices"] >= 0 - if np.any(~keep): + # identify connected contacts and their channel-order + global_device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"] + connected_mask = global_device_channel_indices >= 0 + if np.any(~connected_mask): warn("The given probes have unconnected contacts: they are removed") - probe_as_numpy_array = probe_as_numpy_array[keep] - - device_channel_indices = probe_as_numpy_array["device_channel_indices"] - order = np.argsort(device_channel_indices) - device_channel_indices = device_channel_indices[order] + connected_contact_indices = np.where(connected_mask)[0] + connected_channel_values = global_device_channel_indices[connected_mask] + order = np.argsort(connected_channel_values) + sorted_contact_indices = connected_contact_indices[order] + device_channel_indices = connected_channel_values[order] - # check TODO: Where did this came from? + # validate indices fit the recording number_of_device_channel_indices = np.max(list(device_channel_indices) + [0]) if number_of_device_channel_indices >= self.get_num_channels(): error_msg = ( @@ -173,8 +173,14 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): raise ValueError(error_msg) new_channel_ids = self.get_channel_ids()[device_channel_indices] - probe_as_numpy_array = probe_as_numpy_array[order] - probe_as_numpy_array["device_channel_indices"] = np.arange(probe_as_numpy_array.size, dtype="int64") + + # capture ndim before slicing; get_slice with an empty selection yields a probegroup + # with no probes, on which `.ndim` raises + ndim = probegroup.ndim + + # slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange + probegroup = probegroup.get_slice(sorted_contact_indices) + probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64")) # create recording : channel slice or clone or self if in_place: @@ -187,23 +193,28 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): else: sub_recording = self.select_channels(new_channel_ids) - # create a vector that handle all contacts in property - sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None) + sub_recording._probegroup = probegroup - # planar_contour is saved in annotations - for probe_index, probe in enumerate(probegroup.probes): - contour = probe.probe_planar_contour - if contour is not None: - sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True) + # TODO: revisit whether set_probe with a fully unconnected probe should raise + # instead of returning a zero-channel recording. Preserved here for backwards + # compatibility with a test in test_BaseRecording; that test case should be + # peeled into its own named test so this assumption is easy to find and + # discuss when we decide to tighten the behaviour. + if len(device_channel_indices) == 0: + sub_recording.set_property("location", np.zeros((0, ndim), dtype="float64"), ids=None) + sub_recording.set_property("group", np.zeros(0, dtype="int64"), ids=None) + sub_recording.annotate(probes_info=[]) + return sub_recording - # duplicate positions to "locations" property - ndim = probegroup.ndim + probe_as_numpy_array = probegroup._contact_vector + + # duplicate positions to "location" property so SpikeInterface-level readers keep working locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64") for i, dim in enumerate(["x", "y", "z"][:ndim]): locations[:, i] = probe_as_numpy_array[dim] sub_recording.set_property("location", locations, ids=None) - # handle groups + # derive groups from contact_vector has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields has_contact_side = "contact_sides" in probe_as_numpy_array.dtype.fields if group_mode == "auto": @@ -232,11 +243,16 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False): groups[mask] = group sub_recording.set_property("group", groups, ids=None) - # add probe annotations to recording - probes_info = [] - for probe in probegroup.probes: - probes_info.append(probe.annotations) + # TODO discuss backwards compatibility: mirror probe-level annotations and planar + # contours as recording-level annotations so external code that reads these keys + # keeps working. The canonical source is now `probe.annotations` and + # `probe.probe_planar_contour` on the attached probegroup. + probes_info = [probe.annotations for probe in probegroup.probes] sub_recording.annotate(probes_info=probes_info) + for probe_index, probe in enumerate(probegroup.probes): + contour = probe.probe_planar_contour + if contour is not None: + sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True) return sub_recording @@ -250,30 +266,24 @@ def get_probes(self): return probegroup.probes def get_probegroup(self): - arr = self.get_property("contact_vector") - if arr is None: + if self._probegroup is None: + # Backwards-compat fallback: pre-migration get_probegroup synthesised a dummy + # probe from the "location" property when no probe had been attached. Callers + # (e.g. sparsity.py) rely on this for recordings that have locations but no + # probe. positions = self.get_property("location") if positions is None: raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.") - else: - warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") - probe = self.create_dummy_probe_from_locations(positions) - # probe.create_auto_shape() - probegroup = ProbeGroup() - probegroup.add_probe(probe) - else: - probegroup = ProbeGroup.from_numpy(arr) - - if "probes_info" in self.get_annotation_keys(): - probes_info = self.get_annotation("probes_info") - for probe, probe_info in zip(probegroup.probes, probes_info): - probe.annotations = probe_info - - for probe_index, probe in enumerate(probegroup.probes): - contour = self.get_annotation(f"probe_{probe_index}_planar_contour") - if contour is not None: - probe.set_planar_contour(contour) - return probegroup + warn("There is no Probe attached to this recording. Creating a dummy one with contact positions") + probe = self.create_dummy_probe_from_locations(positions) + pg = ProbeGroup() + pg.add_probe(probe) + return copy.deepcopy(pg) + # Return a deepcopy for backwards compatibility: pre-migration `main` reconstructed + # a fresh `ProbeGroup` from the stored structured array on each call, so external + # callers relied on value semantics. Handing out the live `_probegroup` would be a + # silent behavioural change. + return copy.deepcopy(self._probegroup) def _extra_metadata_from_folder(self, folder): # load probe @@ -284,7 +294,7 @@ def _extra_metadata_from_folder(self, folder): def _extra_metadata_to_folder(self, folder): # save probe - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() write_probeinterface(folder / "probe.json", probegroup) @@ -341,7 +351,7 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params self.set_probe(probe, in_place=True) def set_channel_locations(self, locations, channel_ids=None): - if self.get_property("contact_vector") is not None: + if self.has_probe(): raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)") self.set_property("location", locations, ids=channel_ids) @@ -349,21 +359,18 @@ def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarra if channel_ids is None: channel_ids = self.get_channel_ids() channel_indices = self.ids_to_indices(channel_ids) - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - # here we bypass the probe reconstruction so this works both for probe and probegroup + if self.has_probe(): + contact_vector = self._probegroup._contact_vector ndim = len(axes) all_positions = np.zeros((contact_vector.size, ndim), dtype="float64") for i, dim in enumerate(axes): all_positions[:, i] = contact_vector[dim] - positions = all_positions[channel_indices] - return positions - else: - locations = self.get_property("location") - if locations is None: - raise Exception("There are no channel locations") - locations = np.asarray(locations)[channel_indices] - return select_axes(locations, axes) + return all_positions[channel_indices] + locations = self.get_property("location") + if locations is None: + raise Exception("There are no channel locations") + locations = np.asarray(locations)[channel_indices] + return select_axes(locations, axes) def has_3d_locations(self) -> bool: return self.get_property("location").shape[1] == 3 diff --git a/src/spikeinterface/core/basesnippets.py b/src/spikeinterface/core/basesnippets.py index b56a093ccc..fa47365200 100644 --- a/src/spikeinterface/core/basesnippets.py +++ b/src/spikeinterface/core/basesnippets.py @@ -259,7 +259,7 @@ def _save(self, format="npy", **save_kwargs): else: raise ValueError(f"format {format} not supported") - if self.get_property("contact_vector") is not None: + if self.has_probe(): probegroup = self.get_probegroup() cached.set_probegroup(probegroup) diff --git a/src/spikeinterface/core/channelsaggregationrecording.py b/src/spikeinterface/core/channelsaggregationrecording.py index 697aab875e..90be729664 100644 --- a/src/spikeinterface/core/channelsaggregationrecording.py +++ b/src/spikeinterface/core/channelsaggregationrecording.py @@ -2,6 +2,8 @@ import numpy as np +from probeinterface import Probe, ProbeGroup + from .baserecording import BaseRecording, BaseRecordingSegment @@ -90,11 +92,26 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record break for prop_name, prop_values in property_dict.items(): - if prop_name == "contact_vector": - # remap device channel indices correctly - prop_values["device_channel_indices"] = np.arange(self.get_num_channels()) self.set_property(key=prop_name, values=prop_values) + # split_by resets each child probe's device_channel_indices, so the information + # of which contact was connected to which channel of the parent is lost by the + # time we aggregate. We rebuild a globally-unique wiring via per-probe offsets + # and skip set_probegroup because children also share contact positions. + if all(rec.has_probe() for rec in recording_list): + aggregated_probegroup = ProbeGroup() + offset = 0 + for rec in recording_list: + for probe in rec.get_probegroup().probes: + # round-trip through to_dict/from_dict because Probe.copy() drops + # contact_ids and annotations (probeinterface #421) + probe_copy = Probe.from_dict(probe.to_dict(array_as_list=False)) + n = probe_copy.get_contact_count() + probe_copy.set_device_channel_indices(np.arange(offset, offset + n, dtype="int64")) + aggregated_probegroup.add_probe(probe_copy) + offset += n + self._probegroup = aggregated_probegroup + # if locations are present, check that they are all different! if "location" in self.get_property_keys(): location_tuple = [tuple(loc) for loc in self.get_property("location")] diff --git a/src/spikeinterface/core/channelslice.py b/src/spikeinterface/core/channelslice.py index de693d5c26..f7d498db04 100644 --- a/src/spikeinterface/core/channelslice.py +++ b/src/spikeinterface/core/channelslice.py @@ -61,11 +61,12 @@ def __init__(self, parent_recording, channel_ids=None, renamed_channel_ids=None) parent_recording.copy_metadata(self, only_main=False, ids=self._channel_ids) self._parent = parent_recording - # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + # slice the probegroup to the retained channels and attach via the canonical path + if parent_recording.has_probe(): + parent_probegroup = parent_recording.get_probegroup() + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { @@ -151,11 +152,12 @@ def __init__(self, parent_snippets, channel_ids=None, renamed_channel_ids=None): # copy annotation and properties parent_snippets.copy_metadata(self, only_main=False, ids=self._channel_ids) - # change the wiring of the probe - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + # slice the probegroup to the retained channels and attach via the canonical path + if parent_snippets.has_probe(): + parent_probegroup = parent_snippets.get_probegroup() + sliced_probegroup = parent_probegroup.get_slice(self._parent_channel_indices) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) + self.set_probegroup(sliced_probegroup, in_place=True) # update dump dict self._kwargs = { diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 8e16757bcc..ad444fca4b 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -23,7 +23,7 @@ from spikeinterface.core import BaseRecording, BaseSorting, aggregate_channels, aggregate_units from spikeinterface.core.waveform_tools import has_exceeding_spikes -from .recording_tools import check_probe_do_not_overlap, get_rec_attributes, do_recording_attributes_match +from .recording_tools import get_rec_attributes, do_recording_attributes_match from .core_tools import ( check_json, retrieve_importing_provenance, @@ -363,10 +363,6 @@ def create( f"recording: {recording.sampling_frequency} - sorting: {sorting.sampling_frequency}. " "Ensure that you are associating the correct Recording and Sorting when creating a SortingAnalyzer." ) - # check that multiple probes are non-overlapping - all_probes = recording.get_probegroup().probes - check_probe_do_not_overlap(all_probes) - if has_exceeding_spikes(sorting=sorting, recording=recording): warnings.warn( "Your sorting has spikes with samples times greater than your recording length. These spikes have been removed." diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index a51063af3e..26c01287ad 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -503,7 +503,7 @@ def add_recording_to_zarr_group( ) # save probe - if recording.get_property("contact_vector") is not None: + if recording.has_probe(): probegroup = recording.get_probegroup() zarr_group.attrs["probe"] = check_json(probegroup.to_dict(array_as_list=True)) diff --git a/src/spikeinterface/extractors/neoextractors/biocam.py b/src/spikeinterface/extractors/neoextractors/biocam.py index 8d1fac0c72..c2ceb66523 100644 --- a/src/spikeinterface/extractors/neoextractors/biocam.py +++ b/src/spikeinterface/extractors/neoextractors/biocam.py @@ -1,5 +1,6 @@ from pathlib import Path +import numpy as np import probeinterface from spikeinterface.core.core_tools import define_function_from_class @@ -71,8 +72,9 @@ def __init__( probe_kwargs["electrode_width"] = electrode_width probe = probeinterface.read_3brain(file_path, **probe_kwargs) self.set_probe(probe, in_place=True) - self.set_property("row", self.get_property("contact_vector")["row"]) - self.set_property("col", self.get_property("contact_vector")["col"]) + probe = self.get_probegroup().probes[0] + self.set_property("row", np.asarray(probe.contact_annotations["row"])) + self.set_property("col", np.asarray(probe.contact_annotations["col"])) self._kwargs.update( { diff --git a/src/spikeinterface/extractors/neoextractors/maxwell.py b/src/spikeinterface/extractors/neoextractors/maxwell.py index 932ecee106..875422d00b 100644 --- a/src/spikeinterface/extractors/neoextractors/maxwell.py +++ b/src/spikeinterface/extractors/neoextractors/maxwell.py @@ -75,7 +75,8 @@ def __init__( rec_name = self.neo_reader.rec_name probe = probeinterface.read_maxwell(file_path, well_name=well_name, rec_name=rec_name) self.set_probe(probe, in_place=True) - self.set_property("electrode", self.get_property("contact_vector")["electrode"]) + probe = self.get_probegroup().probes[0] + self.set_property("electrode", np.asarray(probe.contact_annotations["electrode"])) self._kwargs.update(dict(file_path=str(Path(file_path).absolute()), rec_name=rec_name)) @classmethod diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 5306de2441..dfbf5d714d 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -68,11 +68,11 @@ def test_channel_ids(self): def test_gains(self): expected_gains = 2.34375 * np.ones(shape=384) - assert_array_equal(x=self.recording.get_channel_gains(), y=expected_gains) + assert_array_equal(self.recording.get_channel_gains(), expected_gains) def test_offsets(self): expected_offsets = np.zeros(shape=384) - assert_array_equal(x=self.recording.get_channel_offsets(), y=expected_offsets) + assert_array_equal(self.recording.get_channel_offsets(), expected_offsets) def test_probe_representation(self): probe = self.recording.get_probe() @@ -84,7 +84,6 @@ def test_property_keys(self): expected_property_keys = [ "gain_to_uV", "offset_to_uV", - "contact_vector", "location", "group", "shank", @@ -142,11 +141,11 @@ def test_channel_ids(self): def test_gains(self): expected_gains = np.concatenate([2.34375 * np.ones(shape=384), [1171.875]]) - assert_array_equal(x=self.recording.get_channel_gains(), y=expected_gains) + assert_array_equal(self.recording.get_channel_gains(), expected_gains) def test_offsets(self): expected_offsets = np.zeros(shape=385) - assert_array_equal(x=self.recording.get_channel_offsets(), y=expected_offsets) + assert_array_equal(self.recording.get_channel_offsets(), expected_offsets) def test_probe_representation(self): expected_exception = ValueError diff --git a/src/spikeinterface/postprocessing/tests/test_principal_component.py b/src/spikeinterface/postprocessing/tests/test_principal_component.py index 77bff7a3d8..0e65bb2338 100644 --- a/src/spikeinterface/postprocessing/tests/test_principal_component.py +++ b/src/spikeinterface/postprocessing/tests/test_principal_component.py @@ -101,7 +101,7 @@ def test_get_projections(self, sparse): random_spikes_ext = sorting_analyzer.get_extension("random_spikes") random_spikes_indices = random_spikes_ext.get_data() - unit_ids_num_random_spikes = np.sum(random_spikes_ext.params["max_spikes_per_unit"] for _ in some_unit_ids) + unit_ids_num_random_spikes = sum(random_spikes_ext.params["max_spikes_per_unit"] for _ in some_unit_ids) # this should be all spikes all channels some_projections, spike_unit_index = ext.get_some_projections(channel_ids=None, unit_ids=None) diff --git a/src/spikeinterface/preprocessing/basepreprocessor.py b/src/spikeinterface/preprocessing/basepreprocessor.py index 64d57d3637..4e18516a80 100644 --- a/src/spikeinterface/preprocessing/basepreprocessor.py +++ b/src/spikeinterface/preprocessing/basepreprocessor.py @@ -21,6 +21,13 @@ def __init__(self, recording, sampling_frequency=None, channel_ids=None, dtype=N recording.copy_metadata(self, only_main=False, ids=channel_ids) self._parent = recording + # Propagate the attached probegroup. copy_metadata only handles annotations + # and properties; `_probegroup` is a direct attribute and needs its own path. + # Subclasses that change channels (e.g. slicing) should override by slicing + # the probegroup themselves via set_probegroup. + if getattr(recording, "_probegroup", None) is not None and channel_ids is None: + self._probegroup = recording._probegroup + # self._kwargs have to be handled in subclass diff --git a/src/spikeinterface/preprocessing/tests/test_filter.py b/src/spikeinterface/preprocessing/tests/test_filter.py index 398b6cbc0e..e95b456542 100644 --- a/src/spikeinterface/preprocessing/tests/test_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_filter.py @@ -50,7 +50,7 @@ def test_causal_filter_main_kwargs(self, recording_and_data): # Then, change all kwargs to ensure they are propagated # and check the backwards version. - options["band"] = [671] + options["band"] = 671 options["btype"] = "highpass" options["filter_order"] = 8 options["ftype"] = "bessel" diff --git a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py index a571894374..fd835df5c7 100644 --- a/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py +++ b/src/spikeinterface/preprocessing/tests/test_interpolate_bad_channels.py @@ -126,8 +126,10 @@ def test_compare_input_argument_ranges_against_ibl(shanks, p, sigma_um, num_chan # distribute default probe locations across 4 shanks if set rng = np.random.default_rng(seed=None) x = rng.choice(shanks, num_channels) - for idx, __ in enumerate(recording._properties["contact_vector"]): - recording._properties["contact_vector"][idx][1] = x[idx] + probe = recording._probegroup.probes[0] + probe._contact_positions[:, 0] = x + recording._probegroup._build_contact_vector() + recording.set_property("location", recording.get_channel_locations()) # generate random bad channel locations bad_channel_indexes = rng.choice(num_channels, rng.integers(1, int(num_channels / 5)), replace=False) @@ -170,9 +172,12 @@ def test_output_values(): [5, 5, 5, 7, 3], ] # all others equal distance away. # Overwrite the probe information with the new locations + probe = recording._probegroup.probes[0] for idx, (x, y) in enumerate(zip(*new_probe_locs)): - recording._properties["contact_vector"][idx][1] = x - recording._properties["contact_vector"][idx][2] = y + probe._contact_positions[idx, 0] = x + probe._contact_positions[idx, 1] = y + recording._probegroup._build_contact_vector() + recording.set_property("location", recording.get_channel_locations()) # Run interpolation in SI and check the interpolated channel # 0 is a linear combination of other channels @@ -186,8 +191,11 @@ def test_output_values(): # Shift the last channel position so that it is 4 units, rather than 2 # away. Setting sigma_um = p = 1 allows easy calculation of the expected # weights. - recording._properties["contact_vector"][-1][1] = 5 - recording._properties["contact_vector"][-1][2] = 9 + probe = recording._probegroup.probes[0] + probe._contact_positions[-1, 0] = 5 + probe._contact_positions[-1, 1] = 9 + recording._probegroup._build_contact_vector() + recording.set_property("location", recording.get_channel_locations()) expected_weights = np.r_[np.tile(np.exp(-2), 3), np.exp(-4)] expected_weights /= np.sum(expected_weights) diff --git a/src/spikeinterface/preprocessing/zero_channel_pad.py b/src/spikeinterface/preprocessing/zero_channel_pad.py index 45d4809cd8..a84d3bbf64 100644 --- a/src/spikeinterface/preprocessing/zero_channel_pad.py +++ b/src/spikeinterface/preprocessing/zero_channel_pad.py @@ -157,7 +157,7 @@ def __init__(self, recording: BaseRecording, num_channels: int, channel_mapping: "The new mapping cannot exceed total number of channels " "in the zero-chanenl-padded recording." ) else: - if "locations" in recording.get_property_keys() or "contact_vector" in recording.get_property_keys(): + if recording.has_probe() or "location" in recording.get_property_keys(): self.channel_mapping = np.argsort(recording.get_channel_locations()[:, 1]) else: self.channel_mapping = np.arange(recording.get_num_channels()) diff --git a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py index a433eeb643..21a211f099 100644 --- a/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py +++ b/src/spikeinterface/sortingcomponents/motion/motion_interpolation.py @@ -403,13 +403,12 @@ def __init__( dtype_ = fix_dtype(recording, dtype) BasePreprocessor.__init__(self, recording, channel_ids=channel_ids, dtype=dtype_) - if border_mode == "remove_channels": - # change the wiring of the probe - # TODO this is also done in ChannelSliceRecording, this should be done in a common place - contact_vector = self.get_property("contact_vector") - if contact_vector is not None: - contact_vector["device_channel_indices"] = np.arange(len(channel_ids), dtype="int64") - self.set_property("contact_vector", contact_vector) + if border_mode == "remove_channels" and recording.has_probe(): + # slice the probegroup to the retained channels and attach via the canonical path + parent_probegroup = recording.get_probegroup() + sliced_probegroup = parent_probegroup.get_slice(channel_inds) + sliced_probegroup.set_global_device_channel_indices(np.arange(len(channel_ids), dtype="int64")) + self.set_probegroup(sliced_probegroup, in_place=True) # handle manual interpolation_time_bin_centers_s # the case where interpolation_time_bin_size_s is set is handled per-segment below