diff --git a/src/probeinterface/io.py b/src/probeinterface/io.py index ad1a60a..245f3a5 100644 --- a/src/probeinterface/io.py +++ b/src/probeinterface/io.py @@ -202,10 +202,7 @@ def read_BIDS_probe(folder: str | Path, prefix: str | None = None) -> ProbeGroup # create probe object and register with probegroup probe = Probe.from_dataframe(df=df_probe) - probe.annotate(probe_id=probe_id) - probes[str(probe_id)] = probe - probegroup.add_probe(probe) ignore_annotations = [ "probe_ids", @@ -294,6 +291,10 @@ def read_BIDS_probe(folder: str | Path, prefix: str | None = None) -> ProbeGroup probe.annotate(**{contact_param: value_list}) + # Step 5: add probes to probegroup + for probe in probes.values(): + probegroup.add_probe(probe) + return probegroup @@ -337,10 +338,6 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup # Step 1: GENERATION OF PROBE.TSV # ensure required keys (probe_id, probe_type) are present - - if any("probe_id" not in p.annotations for p in probes): - probegroup.auto_generate_probe_ids() - for probe in probes: if "probe_id" not in probe.annotations: raise ValueError( diff --git a/src/probeinterface/probe.py b/src/probeinterface/probe.py index c19ddd6..5b838dc 100644 --- a/src/probeinterface/probe.py +++ b/src/probeinterface/probe.py @@ -99,6 +99,9 @@ def __init__( # vertices for the shape of the probe self.probe_planar_contour = None + # the Probe can belong to a ProbeGroup + self._probe_group = None + # This handles the shank id per contact # If None then one shank only self._shank_ids = None @@ -129,9 +132,6 @@ def __init__( # same idea but handle in vector way for contacts self.contact_annotations = dict() - # the Probe can belong to a ProbeGroup - self._probe_group = None - @property def contact_positions(self): """The position of the center for each contact""" @@ -260,6 +260,11 @@ def annotate(self, **kwargs): ---------- **kwargs : list of keyword arguments to add to the annotations (e.g., brain_area="CA1") """ + if self._probe_group is not None: + raise ValueError( + "You cannot annotate a probe that belongs to a ProbeGroup. " + "Annotate the probe before adding it to the ProbeGroup or use the `ProbeGroup.annotate_probe` method." + ) self.annotations.update(kwargs) self.check_annotations() @@ -271,6 +276,11 @@ def annotate_contacts(self, **kwargs): ---------- **kwargs : list of keyword arguments to add to the annotations (e.g., quality=["good", "bad", ...]) """ + if self._probe_group is not None: + raise ValueError( + "You cannot annotate contacts of a probe that belongs to a ProbeGroup. " + "Annotate the probe before adding it to the ProbeGroup instead." + ) n = self.get_contact_count() for k, values in kwargs.items(): assert len(values) == n, ( @@ -977,7 +987,7 @@ def from_dict(d: dict) -> "Probe": return probe - def to_numpy(self, complete: bool = False) -> np.ndarray: + def to_numpy(self, complete: bool = False, probe_index: int | None = None) -> np.ndarray: """ Export the probe to a numpy structured array. This array handles all contact attributes. @@ -1035,7 +1045,10 @@ def to_numpy(self, complete: bool = False) -> np.ndarray: """ # First define the dtype - dtype = [("x", "float64"), ("y", "float64")] + dtype = [] + if probe_index is not None: + dtype = [("probe_index", "int64")] + dtype += [("x", "float64"), ("y", "float64")] if self.ndim == 3: dtype += [("z", "float64")] @@ -1070,6 +1083,8 @@ def to_numpy(self, complete: bool = False) -> np.ndarray: # Then add the data to the structured array arr = np.zeros(self.get_contact_count(), dtype=dtype) + if probe_index is not None: + arr["probe_index"] = probe_index arr["x"] = self.contact_positions[:, 0] arr["y"] = self.contact_positions[:, 1] if self.ndim == 3: diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index d42906a..0eb3d69 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -1,3 +1,4 @@ +from copy import deepcopy import numpy as np from .utils import generate_unique_ids from .probe import Probe @@ -12,7 +13,60 @@ class ProbeGroup: """ def __init__(self): - self.probes = [] + self._contact_array = None + self._num_probes = 0 + self._probe_contours = [] + self._annotations = [] + + @property + def num_probes(self) -> int: + """ + Get the number of probes in the ProbeGroup + + Returns + ------- + num_probes: int + The number of probes in the ProbeGroup + """ + return int(self._num_probes) + + @property + def probes(self) -> list[Probe]: + """ + Get the list of probes in the ProbeGroup + + Returns + ------- + probes: list of Probe + The list of probes in the ProbeGroup + """ + probes = [] + for probe_index in range(self._num_probes): + probe_mask = self._contact_array["probe_index"] == probe_index + probe_array = self._contact_array[probe_mask] + probe = Probe.from_numpy(probe_array) + # add annotations and contour + probe.annotations = self._annotations[probe_index] + probe.probe_planar_contour = self._probe_contours[probe_index] + probe._probe_group = self + probes.append(probe) + return probes + + def annotate_probe(self, probe_index: int, **annotations) -> None: + """ + Add annotations to a specific probe in the ProbeGroup + + Parameters + ---------- + probe_index: int + The index of the probe to annotate + **annotations: + The annotations to add to the probe + + """ + if probe_index >= self._num_probes: + raise ValueError(f"probe_index {probe_index} is out of bounds for num_probes {self._num_probes}") + self._annotations[probe_index].update(annotations) def add_probe(self, probe: Probe) -> None: """ @@ -27,8 +81,44 @@ def add_probe(self, probe: Probe) -> None: if len(self.probes) > 0: self._check_compatible(probe) - self.probes.append(probe) + probe_array = probe.to_numpy(complete=True, probe_index=self._num_probes) + probe_dtype = probe_array.dtype + if probe.contact_ids is None: + count = probe.get_contact_count() + width = len(str(count - 1)) # or count, depending on whether you want inclusive + contact_ids = [f"{i:0{width}d}" for i in range(count)] + probe_array["contact_ids"] = contact_ids + if self._contact_array is None: + self._contact_array = probe_array + else: + # Handle the case where the new probe has a different dtype than the existing contact array + # e.g., one probe has square contacts and the other has circular contacts, so different shape parameters + existing_dtype = self._contact_array.dtype + if existing_dtype != probe_dtype: + fields_to_add = [f for f in probe_dtype.fields if f not in existing_dtype.fields] + new_dtype = probe_dtype + + # Create a new dtype that is the union of the existing and new dtypes + new_fields = list(existing_dtype.descr) + [ + f for f in probe_dtype.descr if f[0] not in existing_dtype.fields + ] + new_dtype = np.dtype(new_fields) + # Create a new array with the new dtype and copy existing data + new_contact_array = np.zeros(self._contact_array.shape, dtype=new_dtype) + new_probe_array = np.zeros(probe_array.shape, dtype=new_dtype) + for name in existing_dtype.names: + new_contact_array[name] = self._contact_array[name] + for name in probe_dtype.names: + new_probe_array[name] = probe_array[name] + self._contact_array = new_contact_array + probe_array = new_probe_array + self._contact_array = np.concatenate((self._contact_array, probe_array), axis=0) + self._probe_contours.append(probe.probe_planar_contour) + annotations = probe.annotations + annotations["probe_id"] = probe.annotations.get("probe_id", f"probe_{self._num_probes}") + self._annotations.append(annotations) probe._probe_group = self + self._num_probes += 1 def _check_compatible(self, probe: Probe) -> None: if probe._probe_group is not None: @@ -42,9 +132,7 @@ def _check_compatible(self, probe: Probe) -> None: ) # check global channel maps - self.probes.append(probe) - self.check_global_device_wiring_and_ids() - self.probes = self.probes[:-1] + self.check_global_device_wiring_and_ids(new_device_channel_indices=probe.device_channel_indices) @property def ndim(self) -> int: @@ -60,10 +148,10 @@ def copy(self) -> "ProbeGroup": A copy of the ProbeGroup """ copy = ProbeGroup() - for probe in self.probes: - copy.add_probe(probe.copy()) - global_device_channel_indices = self.get_global_device_channel_indices()["device_channel_indices"] - copy.set_global_device_channel_indices(global_device_channel_indices) + copy._num_probes = self._num_probes + copy._contact_array = self._contact_array.copy() + copy._probe_contours = deepcopy(self._probe_contours) + copy._annotations = deepcopy(self._annotations) return copy def get_contact_count(self) -> int: @@ -88,33 +176,28 @@ def to_numpy(self, complete: bool = False) -> np.ndarray: If True, export complete information about the probegroup including contact_plane_axes/si_units/device_channel_indices """ - - fields = [] - probe_arr = [] - - # loop over probes to get all fields - dtype = [("probe_index", "int64")] - fields = [] - for probe_index, probe in enumerate(self.probes): - arr = probe.to_numpy(complete=complete) - probe_arr.append(arr) - for k in arr.dtype.fields: - if k not in fields: - fields.append(k) - dtype += [(k, arr.dtype.fields[k][0])] - - pg_arr = [] - for probe_index, probe in enumerate(self.probes): - arr = probe_arr[probe_index] - arr_ext = np.zeros(probe.get_contact_count(), dtype=dtype) - arr_ext["probe_index"] = probe_index - for k in fields: - if k in arr.dtype.fields: - arr_ext[k] = arr[k] - pg_arr.append(arr_ext) - - pg_arr = np.concatenate(pg_arr, axis=0) - return pg_arr + if complete: + return self._contact_array.copy() + else: + # Remove fields that are not in the default export + all_probe_fields = [] + for probe_index in range(self._num_probes): + probe_fields = self.probes[probe_index].to_numpy(complete=False, probe_index=0).dtype.fields + for f in probe_fields: + if f not in all_probe_fields: + all_probe_fields.append(f) + probe_fields = all_probe_fields + + fields_to_remove = [f for f in self._contact_array.dtype.names if f not in probe_fields] + dtype = [ + (name, self._contact_array.dtype.fields[name][0]) + for name in self._contact_array.dtype.names + if name not in fields_to_remove + ] + arr = np.zeros(self._contact_array.shape, dtype=dtype) + for name in arr.dtype.names: + arr[name] = self._contact_array[name] + return arr @staticmethod def from_numpy(arr: np.ndarray) -> "ProbeGroup": @@ -135,10 +218,11 @@ def from_numpy(arr: np.ndarray) -> "ProbeGroup": probes_indices = np.unique(arr["probe_index"]) probegroup = ProbeGroup() + probegroup._contact_array = arr.copy() for probe_index in probes_indices: - mask = arr["probe_index"] == probe_index - probe = Probe.from_numpy(arr[mask]) - probegroup.add_probe(probe) + probegroup._probe_contours.append(None) + probegroup._annotations.append({}) + probegroup._num_probes = len(probes_indices) return probegroup def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame": @@ -221,9 +305,8 @@ def get_global_device_channel_indices(self) -> np.ndarray: """ total_chan = self.get_contact_count() channels = np.zeros(total_chan, dtype=[("probe_index", "int64"), ("device_channel_indices", "int64")]) - arr = self.to_numpy(complete=True) - channels["probe_index"] = arr["probe_index"] - channels["device_channel_indices"] = arr["device_channel_indices"] + channels["probe_index"] = self._contact_array["probe_index"] + channels["device_channel_indices"] = self._contact_array["device_channel_indices"] return channels def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None: @@ -240,18 +323,7 @@ def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None raise ValueError( f"Wrong channels size {channels.size} for the number of channels {self.get_contact_count()}" ) - - # first reset previous indices - for i, probe in enumerate(self.probes): - n = probe.get_contact_count() - probe.set_device_channel_indices([-1] * n) - - # then set new indices - ind = 0 - for i, probe in enumerate(self.probes): - n = probe.get_contact_count() - probe.set_device_channel_indices(channels[ind : ind + n]) - ind += n + self._contact_array["device_channel_indices"] = channels def get_global_contact_ids(self) -> np.ndarray: """ @@ -262,7 +334,7 @@ def get_global_contact_ids(self) -> np.ndarray: contact_ids: np.ndarray An array of the contaact ids across all probes """ - contact_ids = self.to_numpy(complete=True)["contact_ids"] + contact_ids = self._contact_array["contact_ids"] return contact_ids def get_global_contact_positions(self) -> np.ndarray: @@ -322,79 +394,33 @@ def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup": if len(selection_indices) == 0: return ProbeGroup() - # Map selection to indices of individual probes - ind = 0 - sliced_probes = [] - for probe in self.probes: - n = probe.get_contact_count() - probe_limits = (ind, ind + n) - ind += n - - probe_selection_indices = selection_indices[ - (selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1]) - ] - if len(probe_selection_indices) == 0: - continue - sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0]) - sliced_probes.append(sliced_probe) + full_contact_array = self._contact_array + sliced_contact_array = full_contact_array[selection_indices] + probe_indices = np.unique(sliced_contact_array["probe_index"]) + new_probe_contours = [] + new_annotations = [] + for new_probe_index, old_probe_index in enumerate(probe_indices): + sliced_contact_array["probe_index"][ + sliced_contact_array["probe_index"] == old_probe_index + ] = new_probe_index + new_probe_contours.append(self._probe_contours[old_probe_index]) + new_annotations.append(self._annotations[old_probe_index]) sliced_probe_group = ProbeGroup() - for probe in sliced_probes: - sliced_probe_group.add_probe(probe) - + sliced_probe_group._contact_array = sliced_contact_array + sliced_probe_group._num_probes = len(probe_indices) + sliced_probe_group._probe_contours = new_probe_contours + sliced_probe_group._annotations = new_annotations return sliced_probe_group - def check_global_device_wiring_and_ids(self) -> None: + def check_global_device_wiring_and_ids(self, new_device_channel_indices: np.ndarray | None = None) -> None: # check unique device_channel_indices for !=-1 - chans = self.get_global_device_channel_indices() - keep = chans["device_channel_indices"] >= 0 - valid_chans = chans[keep]["device_channel_indices"] + chans = self.get_global_device_channel_indices()["device_channel_indices"] + if new_device_channel_indices is not None: + chans = np.concatenate([chans, new_device_channel_indices]) + + keep = chans >= 0 + valid_chans = chans[keep] if valid_chans.size != np.unique(valid_chans).size: raise ValueError("channel device indices are not unique across probes") - - def auto_generate_probe_ids(self, *args, **kwargs) -> None: - """ - Annotate all probes with unique probe_id values. - - Parameters - ---------- - *args: will be forwarded to `probeinterface.utils.generate_unique_ids` - **kwargs: will be forwarded to - `probeinterface.utils.generate_unique_ids` - """ - - if any("probe_id" in p.annotations for p in self.probes): - raise ValueError("Probe already has a `probe_id` annotation.") - - if not args: - args = 1e7, 1e8 - # 3rd argument has to be the number of probes - args = args[:2] + (len(self.probes),) - - # creating unique probe ids in case probes do not have any yet - probe_ids = generate_unique_ids(*args, **kwargs).astype(str) - for pid, probe in enumerate(self.probes): - probe.annotate(probe_id=probe_ids[pid]) - - def auto_generate_contact_ids(self, *args, **kwargs) -> None: - """ - Annotate all contacts with unique contact_id values. - - Parameters - ---------- - *args: will be forwarded to `probeinterface.utils.generate_unique_ids` - **kwargs: will be forwarded to - `probeinterface.utils.generate_unique_ids` - """ - - if not args: - args = 1e7, 1e8 - # 3rd argument has to be the number of probes - args = args[:2] + (self.get_contact_count(),) - - contact_ids = generate_unique_ids(*args, **kwargs).astype(str) - - for probe in self.probes: - el_ids, contact_ids = np.split(contact_ids, [probe.get_contact_count()]) - probe.set_contact_ids(el_ids) diff --git a/tests/test_io/test_io.py b/tests/test_io/test_io.py index 9a69d3b..ab30952 100644 --- a/tests/test_io/test_io.py +++ b/tests/test_io/test_io.py @@ -69,22 +69,21 @@ def test_BIDS_format(tmp_path): # add custom probe type annotation to be # compatible with BIDS specifications - for probe in probegroup.probes: - probe.annotate(type="laminar") - - # add unique contact ids to be compatible - # with BIDS specifications - n_els = sum([p.get_contact_count() for p in probegroup.probes]) - # using np.random.choice to ensure uniqueness of contact ids - el_ids = np.random.choice(np.arange(1e4, 1e5, dtype="int"), replace=False, size=n_els).astype(str) - for probe in probegroup.probes: - probe_el_ids, el_ids = np.split(el_ids, [probe.get_contact_count()]) - probe.set_contact_ids(probe_el_ids) - - # switch to more generic dtype for shank_ids - if probe.shank_ids is not None: - probe.set_shank_ids(probe.shank_ids.astype(str)) - + for probe_index in range(probegroup.num_probes): + probegroup.annotate_probe(probe_index, type="laminar") + + # # add unique contact ids to be compatible + # # with BIDS specifications + # n_els = sum([p.get_contact_count() for p in probegroup.probes]) + # # using np.random.choice to ensure uniqueness of contact ids + # el_ids = np.random.choice(np.arange(1e4, 1e5, dtype="int"), replace=False, size=n_els).astype(str) + # for probe in probegroup.probes: + # probe_el_ids, el_ids = np.split(el_ids, [probe.get_contact_count()]) + # probe.set_contact_ids(probe_el_ids) + + # # switch to more generic dtype for shank_ids + # if probe.shank_ids is not None: + # probe.set_shank_ids(probe.shank_ids.astype(str)) write_BIDS_probe(folder_path, probegroup) probegroup_read = read_BIDS_probe(folder_path) @@ -115,8 +114,8 @@ def test_BIDS_format(tmp_path): assert shape_params == probe_read.contact_shape_params[t][sid] for i in range(len(probe_orig.contact_positions)): assert all(probe_orig.contact_positions[i] == probe_read.contact_positions[t][i]) - for i in range(len(probe.contact_plane_axes)): - for dim in range(len(probe.contact_plane_axes[i])): + for i in range(len(probe_orig.contact_plane_axes)): + for dim in range(len(probe_orig.contact_plane_axes[i])): assert all(probe_orig.contact_plane_axes[i][dim] == probe_read.contact_plane_axes[t][i][dim]) diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index c942190..7fbf606 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -37,11 +37,9 @@ def test_probegroup(probegroup): other = ProbeGroup.from_dict(d) # checking automatic generation of ids with new dummy probes - probegroup.probes = [] + probegroup = ProbeGroup() for i in range(3): probegroup.add_probe(generate_dummy_probe()) - probegroup.auto_generate_contact_ids() - probegroup.auto_generate_probe_ids() for p in probegroup.probes: assert p.contact_ids is not None @@ -179,13 +177,6 @@ def test_copy_preserves_device_channel_indices(probegroup): ) -def test_copy_does_not_preserve_contact_ids(probegroup): - """Probe.copy() intentionally does not copy contact_ids.""" - pg_copy = probegroup.copy() - # All contact_ids should be empty strings after copy - assert all(cid == "" for cid in pg_copy.get_global_contact_ids()) - - def test_copy_is_independent(probegroup): """Mutating the copy must not affect the original.""" original_positions = probegroup.probes[0].contact_positions.copy() @@ -254,6 +245,13 @@ def test_get_slice_all_contacts(probegroup): ) +def test_reset_of_probe_indexing(probegroup): + """Test that after slicing, the probe indexing is reset to 0..N-1.""" + indices = np.arange(probegroup.probes[0].get_contact_count() + 2) # some contacts from probe 0 and 1 + sliced = probegroup.get_slice(indices) + assert np.all(sliced._contact_array["probe_index"]) in [0, 1] # should be reset to 0 and 1 + + if __name__ == "__main__": test_probegroup() # ~ test_probegroup_3d()