Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ test_core = [

# for github test : probeinterface and neo from master
# for release we need PyPI, so this needs to be commented out
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
# TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge
"probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",

# for slurm jobs,
Expand All @@ -140,7 +141,8 @@ test_extractors = [
"pooch>=1.8.2",
"datalad>=1.0.2",
# Commenting out for release
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
# TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge
"probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",
]

Expand Down Expand Up @@ -191,7 +193,8 @@ test = [

# for github test : probeinterface and neo from master
# for release we need PyPI, so this needs to be commented out
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git",
# TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge
"probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git",

# for slurm jobs
Expand Down Expand Up @@ -220,7 +223,8 @@ docs = [
"huggingface_hub", # For automated curation

# for release we need pypi, so this needs to be commented
"probeinterface @ git+https://github.com/SpikeInterface/probeinterface.git", # We always build from the latest version
# TEMP: point at the probeinterface #425 branch so CI exercises it; revert before merge
"probeinterface @ git+https://github.com/h-mayorquin/probeinterface.git@dense_array_handle",
"neo @ git+https://github.com/NeuralEnsemble/python-neo.git", # We always build from the latest version
]

Expand Down
48 changes: 48 additions & 0 deletions src/spikeinterface/core/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,9 @@ def to_dict(
folder_metadata = Path(folder_metadata).resolve().absolute().relative_to(relative_to)
dump_dict["folder_metadata"] = str(folder_metadata)

if getattr(self, "_probegroup", None) is not None:
dump_dict["probegroup"] = self._probegroup.to_dict(array_as_list=True)

return dump_dict

@staticmethod
Expand Down Expand Up @@ -1155,9 +1158,54 @@ def _load_extractor_from_dict(dic) -> "BaseExtractor":
for k, v in dic["properties"].items():
extractor.set_property(k, v)

if "probegroup" in dic:
from probeinterface import ProbeGroup

probegroup = ProbeGroup.from_dict(dic["probegroup"])
if hasattr(extractor, "set_probegroup"):
extractor.set_probegroup(probegroup, in_place=True)
else:
extractor._probegroup = probegroup
elif "contact_vector" in dic.get("properties", {}):
_restore_probegroup_from_legacy_contact_vector(extractor)

return extractor


def _restore_probegroup_from_legacy_contact_vector(extractor) -> None:
    """
    Rebuild a `ProbeGroup` from the legacy `contact_vector` property and attach it.

    Before the probegroup refactor, recordings persisted their probe layout as a
    structured numpy array under the `contact_vector` property, with probe-level
    annotations in a separate `probes_info` annotation and each probe's planar
    contour under a `probe_{i}_planar_contour` annotation. This helper reassembles
    a `ProbeGroup` from those legacy fields, attaches it through the canonical
    `set_probegroup` path when the extractor provides one, and removes the legacy
    property so the old and new representations do not coexist on the loaded
    extractor.
    """
    from probeinterface import ProbeGroup

    legacy_array = extractor.get_property("contact_vector")
    probegroup = ProbeGroup.from_numpy(legacy_array)

    # Restore per-probe annotations saved under the legacy "probes_info" annotation.
    if "probes_info" in extractor.get_annotation_keys():
        for probe, info in zip(probegroup.probes, extractor.get_annotation("probes_info")):
            probe.annotations = info

    # Restore planar contours; read the raw annotation dict so a missing key is a no-op.
    for index, probe in enumerate(probegroup.probes):
        contour = extractor._annotations.get(f"probe_{index}_planar_contour")
        if contour is not None:
            probe.set_planar_contour(contour)

    # Prefer the canonical attachment path; fall back to the private attribute.
    if hasattr(extractor, "set_probegroup"):
        extractor.set_probegroup(probegroup, in_place=True)
    else:
        extractor._probegroup = probegroup

    # Drop the legacy property now that the probegroup carries the same information.
    extractor._properties.pop("contact_vector", None)


def _get_class_from_string(class_string):
class_name = class_string.split(".")[-1]
module = ".".join(class_string.split(".")[:-1])
Expand Down
4 changes: 2 additions & 2 deletions src/spikeinterface/core/baserecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ def _save(self, format="binary", verbose: bool = False, **save_kwargs):
else:
raise ValueError(f"format {format} not supported")

if self.get_property("contact_vector") is not None:
if self.has_probe():
probegroup = self.get_probegroup()
cached.set_probegroup(probegroup)

Expand All @@ -409,7 +409,7 @@ def _extra_metadata_from_folder(self, folder):

def _extra_metadata_to_folder(self, folder):
# save probe
if self.get_property("contact_vector") is not None:
if self.has_probe():
probegroup = self.get_probegroup()
write_probeinterface(folder / "probe.json", probegroup)

Expand Down
133 changes: 70 additions & 63 deletions src/spikeinterface/core/baserecordingsnippets.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
from pathlib import Path

import numpy as np
Expand All @@ -19,6 +20,7 @@ def __init__(self, sampling_frequency: float, channel_ids: list[str, int], dtype
BaseExtractor.__init__(self, channel_ids)
self._sampling_frequency = float(sampling_frequency)
self._dtype = np.dtype(dtype)
self._probegroup = None

@property
def channel_ids(self):
Expand Down Expand Up @@ -51,7 +53,7 @@ def has_scaleable_traces(self) -> bool:
return True

def has_probe(self) -> bool:
return "contact_vector" in self.get_property_keys()
return self._probegroup is not None

def has_channel_location(self) -> bool:
return self.has_probe() or "location" in self.get_property_keys()
Expand Down Expand Up @@ -145,21 +147,19 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False):
probe.device_channel_indices is not None for probe in probegroup.probes
), "Probe must have device_channel_indices"

# this is a structured vector with compound fields (dataframe-like) that holds all contact attributes
probe_as_numpy_array = probegroup.to_numpy(complete=True)

# keep only connected contact ( != -1)
keep = probe_as_numpy_array["device_channel_indices"] >= 0
if np.any(~keep):
# identify connected contacts and their channel-order
global_device_channel_indices = probegroup.get_global_device_channel_indices()["device_channel_indices"]
connected_mask = global_device_channel_indices >= 0
if np.any(~connected_mask):
warn("The given probes have unconnected contacts: they are removed")

probe_as_numpy_array = probe_as_numpy_array[keep]

device_channel_indices = probe_as_numpy_array["device_channel_indices"]
order = np.argsort(device_channel_indices)
device_channel_indices = device_channel_indices[order]
connected_contact_indices = np.where(connected_mask)[0]
connected_channel_values = global_device_channel_indices[connected_mask]
order = np.argsort(connected_channel_values)
sorted_contact_indices = connected_contact_indices[order]
device_channel_indices = connected_channel_values[order]

# check TODO: Where did this come from?
# validate indices fit the recording
number_of_device_channel_indices = np.max(list(device_channel_indices) + [0])
if number_of_device_channel_indices >= self.get_num_channels():
error_msg = (
Expand All @@ -173,8 +173,14 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False):
raise ValueError(error_msg)

new_channel_ids = self.get_channel_ids()[device_channel_indices]
probe_as_numpy_array = probe_as_numpy_array[order]
probe_as_numpy_array["device_channel_indices"] = np.arange(probe_as_numpy_array.size, dtype="int64")

# capture ndim before slicing; get_slice with an empty selection yields a probegroup
# with no probes, on which `.ndim` raises
ndim = probegroup.ndim

# slice + reorder probegroup so contact order matches the recording's channel order, and reset wiring to arange
probegroup = probegroup.get_slice(sorted_contact_indices)
probegroup.set_global_device_channel_indices(np.arange(len(device_channel_indices), dtype="int64"))

# create recording : channel slice or clone or self
if in_place:
Expand All @@ -187,23 +193,28 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False):
else:
sub_recording = self.select_channels(new_channel_ids)

# create a vector that handle all contacts in property
sub_recording.set_property("contact_vector", probe_as_numpy_array, ids=None)
sub_recording._probegroup = probegroup

# planar_contour is saved in annotations
for probe_index, probe in enumerate(probegroup.probes):
contour = probe.probe_planar_contour
if contour is not None:
sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True)
# TODO: revisit whether set_probe with a fully unconnected probe should raise
# instead of returning a zero-channel recording. Preserved here for backwards
# compatibility with a test in test_BaseRecording; that test case should be
# peeled into its own named test so this assumption is easy to find and
# discuss when we decide to tighten the behaviour.
if len(device_channel_indices) == 0:
sub_recording.set_property("location", np.zeros((0, ndim), dtype="float64"), ids=None)
sub_recording.set_property("group", np.zeros(0, dtype="int64"), ids=None)
sub_recording.annotate(probes_info=[])
return sub_recording

# duplicate positions to "locations" property
ndim = probegroup.ndim
probe_as_numpy_array = probegroup._contact_vector

# duplicate positions to "location" property so SpikeInterface-level readers keep working
locations = np.zeros((probe_as_numpy_array.size, ndim), dtype="float64")
for i, dim in enumerate(["x", "y", "z"][:ndim]):
locations[:, i] = probe_as_numpy_array[dim]
sub_recording.set_property("location", locations, ids=None)

# handle groups
# derive groups from contact_vector
has_shank_id = "shank_ids" in probe_as_numpy_array.dtype.fields
has_contact_side = "contact_sides" in probe_as_numpy_array.dtype.fields
if group_mode == "auto":
Expand Down Expand Up @@ -232,11 +243,16 @@ def _set_probes(self, probe_or_probegroup, group_mode="auto", in_place=False):
groups[mask] = group
sub_recording.set_property("group", groups, ids=None)

# add probe annotations to recording
probes_info = []
for probe in probegroup.probes:
probes_info.append(probe.annotations)
# TODO discuss backwards compatibility: mirror probe-level annotations and planar
# contours as recording-level annotations so external code that reads these keys
# keeps working. The canonical source is now `probe.annotations` and
# `probe.probe_planar_contour` on the attached probegroup.
probes_info = [probe.annotations for probe in probegroup.probes]
sub_recording.annotate(probes_info=probes_info)
for probe_index, probe in enumerate(probegroup.probes):
contour = probe.probe_planar_contour
if contour is not None:
sub_recording.set_annotation(f"probe_{probe_index}_planar_contour", contour, overwrite=True)

return sub_recording

Expand All @@ -250,30 +266,24 @@ def get_probes(self):
return probegroup.probes

def get_probegroup(self):
arr = self.get_property("contact_vector")
if arr is None:
if self._probegroup is None:
# Backwards-compat fallback: pre-migration get_probegroup synthesised a dummy
# probe from the "location" property when no probe had been attached. Callers
# (e.g. sparsity.py) rely on this for recordings that have locations but no
# probe.
positions = self.get_property("location")
if positions is None:
raise ValueError("There is no Probe attached to this recording. Use set_probe(...) to attach one.")
else:
warn("There is no Probe attached to this recording. Creating a dummy one with contact positions")
probe = self.create_dummy_probe_from_locations(positions)
# probe.create_auto_shape()
probegroup = ProbeGroup()
probegroup.add_probe(probe)
else:
probegroup = ProbeGroup.from_numpy(arr)

if "probes_info" in self.get_annotation_keys():
probes_info = self.get_annotation("probes_info")
for probe, probe_info in zip(probegroup.probes, probes_info):
probe.annotations = probe_info

for probe_index, probe in enumerate(probegroup.probes):
contour = self.get_annotation(f"probe_{probe_index}_planar_contour")
if contour is not None:
probe.set_planar_contour(contour)
return probegroup
warn("There is no Probe attached to this recording. Creating a dummy one with contact positions")
probe = self.create_dummy_probe_from_locations(positions)
pg = ProbeGroup()
pg.add_probe(probe)
return copy.deepcopy(pg)
# Return a deepcopy for backwards compatibility: pre-migration `main` reconstructed
# a fresh `ProbeGroup` from the stored structured array on each call, so external
# callers relied on value semantics. Handing out the live `_probegroup` would be a
# silent behavioural change.
return copy.deepcopy(self._probegroup)

def _extra_metadata_from_folder(self, folder):
# load probe
Expand All @@ -284,7 +294,7 @@ def _extra_metadata_from_folder(self, folder):

def _extra_metadata_to_folder(self, folder):
# save probe
if self.get_property("contact_vector") is not None:
if self.has_probe():
probegroup = self.get_probegroup()
write_probeinterface(folder / "probe.json", probegroup)

Expand Down Expand Up @@ -341,29 +351,26 @@ def set_dummy_probe_from_locations(self, locations, shape="circle", shape_params
self.set_probe(probe, in_place=True)

def set_channel_locations(self, locations, channel_ids=None):
if self.get_property("contact_vector") is not None:
if self.has_probe():
raise ValueError("set_channel_locations(..) destroys the probe description, prefer _set_probes(..)")
self.set_property("location", locations, ids=channel_ids)

def get_channel_locations(self, channel_ids=None, axes: str = "xy") -> np.ndarray:
if channel_ids is None:
channel_ids = self.get_channel_ids()
channel_indices = self.ids_to_indices(channel_ids)
contact_vector = self.get_property("contact_vector")
if contact_vector is not None:
# here we bypass the probe reconstruction so this works both for probe and probegroup
if self.has_probe():
contact_vector = self._probegroup._contact_vector
ndim = len(axes)
all_positions = np.zeros((contact_vector.size, ndim), dtype="float64")
for i, dim in enumerate(axes):
all_positions[:, i] = contact_vector[dim]
positions = all_positions[channel_indices]
return positions
else:
locations = self.get_property("location")
if locations is None:
raise Exception("There are no channel locations")
locations = np.asarray(locations)[channel_indices]
return select_axes(locations, axes)
return all_positions[channel_indices]
locations = self.get_property("location")
if locations is None:
raise Exception("There are no channel locations")
locations = np.asarray(locations)[channel_indices]
return select_axes(locations, axes)

def has_3d_locations(self) -> bool:
return self.get_property("location").shape[1] == 3
Expand Down
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesnippets.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def _save(self, format="npy", **save_kwargs):
else:
raise ValueError(f"format {format} not supported")

if self.get_property("contact_vector") is not None:
if self.has_probe():
probegroup = self.get_probegroup()
cached.set_probegroup(probegroup)

Expand Down
23 changes: 20 additions & 3 deletions src/spikeinterface/core/channelsaggregationrecording.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import numpy as np

from probeinterface import Probe, ProbeGroup

from .baserecording import BaseRecording, BaseRecordingSegment


Expand Down Expand Up @@ -90,11 +92,26 @@ def __init__(self, recording_list_or_dict=None, renamed_channel_ids=None, record
break

for prop_name, prop_values in property_dict.items():
if prop_name == "contact_vector":
# remap device channel indices correctly
prop_values["device_channel_indices"] = np.arange(self.get_num_channels())
self.set_property(key=prop_name, values=prop_values)

# split_by resets each child probe's device_channel_indices, so the information
# of which contact was connected to which channel of the parent is lost by the
# time we aggregate. We rebuild a globally-unique wiring via per-probe offsets
# and skip set_probegroup because children also share contact positions.
if all(rec.has_probe() for rec in recording_list):
aggregated_probegroup = ProbeGroup()
offset = 0
for rec in recording_list:
for probe in rec.get_probegroup().probes:
# round-trip through to_dict/from_dict because Probe.copy() drops
# contact_ids and annotations (probeinterface #421)
probe_copy = Probe.from_dict(probe.to_dict(array_as_list=False))
n = probe_copy.get_contact_count()
probe_copy.set_device_channel_indices(np.arange(offset, offset + n, dtype="int64"))
aggregated_probegroup.add_probe(probe_copy)
offset += n
self._probegroup = aggregated_probegroup

# if locations are present, check that they are all different!
if "location" in self.get_property_keys():
location_tuple = [tuple(loc) for loc in self.get_property("location")]
Expand Down
Loading
Loading