Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 4 additions & 7 deletions src/probeinterface/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,10 +202,7 @@ def read_BIDS_probe(folder: str | Path, prefix: str | None = None) -> ProbeGroup

# create probe object and register with probegroup
probe = Probe.from_dataframe(df=df_probe)
probe.annotate(probe_id=probe_id)

probes[str(probe_id)] = probe
probegroup.add_probe(probe)

ignore_annotations = [
"probe_ids",
Expand Down Expand Up @@ -294,6 +291,10 @@ def read_BIDS_probe(folder: str | Path, prefix: str | None = None) -> ProbeGroup

probe.annotate(**{contact_param: value_list})

# Step 5: add probes to probegroup
for probe in probes.values():
probegroup.add_probe(probe)

return probegroup


Expand Down Expand Up @@ -337,10 +338,6 @@ def write_BIDS_probe(folder: str | Path, probe_or_probegroup: Probe | ProbeGroup

# Step 1: GENERATION OF PROBE.TSV
# ensure required keys (probe_id, probe_type) are present

if any("probe_id" not in p.annotations for p in probes):
probegroup.auto_generate_probe_ids()

for probe in probes:
if "probe_id" not in probe.annotations:
raise ValueError(
Expand Down
25 changes: 20 additions & 5 deletions src/probeinterface/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ def __init__(
# vertices for the shape of the probe
self.probe_planar_contour = None

# the Probe can belong to a ProbeGroup
self._probe_group = None

# This handles the shank id per contact
# If None then one shank only
self._shank_ids = None
Expand Down Expand Up @@ -129,9 +132,6 @@ def __init__(
# same idea but handle in vector way for contacts
self.contact_annotations = dict()

# the Probe can belong to a ProbeGroup
self._probe_group = None

@property
def contact_positions(self):
"""The position of the center for each contact"""
Expand Down Expand Up @@ -260,6 +260,11 @@ def annotate(self, **kwargs):
----------
**kwargs : list of keyword arguments to add to the annotations (e.g., brain_area="CA1")
"""
if self._probe_group is not None:
raise ValueError(
"You cannot annotate a probe that belongs to a ProbeGroup. "
"Annotate the probe before adding it to the ProbeGroup or use the `ProbeGroup.annotate_probe` method."
)
self.annotations.update(kwargs)
self.check_annotations()

Expand All @@ -271,6 +276,11 @@ def annotate_contacts(self, **kwargs):
----------
**kwargs : list of keyword arguments to add to the annotations (e.g., quality=["good", "bad", ...])
"""
if self._probe_group is not None:
raise ValueError(
"You cannot annotate contacts of a probe that belongs to a ProbeGroup. "
"Annotate the probe before adding it to the ProbeGroup instead."
)
n = self.get_contact_count()
for k, values in kwargs.items():
assert len(values) == n, (
Expand Down Expand Up @@ -977,7 +987,7 @@ def from_dict(d: dict) -> "Probe":

return probe

def to_numpy(self, complete: bool = False) -> np.ndarray:
def to_numpy(self, complete: bool = False, probe_index: int | None = None) -> np.ndarray:
"""
Export the probe to a numpy structured array.
This array handles all contact attributes.
Expand Down Expand Up @@ -1035,7 +1045,10 @@ def to_numpy(self, complete: bool = False) -> np.ndarray:
"""

# First define the dtype
dtype = [("x", "float64"), ("y", "float64")]
dtype = []
if probe_index is not None:
dtype = [("probe_index", "int64")]
dtype += [("x", "float64"), ("y", "float64")]
if self.ndim == 3:
dtype += [("z", "float64")]

Expand Down Expand Up @@ -1070,6 +1083,8 @@ def to_numpy(self, complete: bool = False) -> np.ndarray:

# Then add the data to the structured array
arr = np.zeros(self.get_contact_count(), dtype=dtype)
if probe_index is not None:
arr["probe_index"] = probe_index
arr["x"] = self.contact_positions[:, 0]
arr["y"] = self.contact_positions[:, 1]
if self.ndim == 3:
Expand Down
Loading
Loading