Skip to content
113 changes: 104 additions & 9 deletions src/probeinterface/probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class ProbeGroup:
def __init__(self):
self.probes = []

def add_probe(self, probe: Probe):
def add_probe(self, probe: Probe) -> None:
"""
Add an additional probe to the ProbeGroup

Expand All @@ -30,7 +30,7 @@ def add_probe(self, probe: Probe):
self.probes.append(probe)
probe._probe_group = self

def _check_compatible(self, probe: Probe):
def _check_compatible(self, probe: Probe) -> None:
if probe._probe_group is not None:
raise ValueError(
"This probe is already attached to another ProbeGroup. Use probe.copy() to attach it to another ProbeGroup"
Expand All @@ -47,9 +47,25 @@ def _check_compatible(self, probe: Probe):
self.probes = self.probes[:-1]

@property
def ndim(self):
def ndim(self) -> int:
return self.probes[0].ndim

def copy(self) -> "ProbeGroup":
"""
Create a copy of the ProbeGroup

Returns
-------
copy: ProbeGroup
A copy of the ProbeGroup
"""
copy = ProbeGroup()
for probe in self.probes:
copy.add_probe(probe.copy())
global_device_channel_indices = self.get_global_device_channel_indices()["device_channel_indices"]
copy.set_global_device_channel_indices(global_device_channel_indices)
return copy

def get_contact_count(self) -> int:
"""
Total number of channels.
Expand Down Expand Up @@ -147,7 +163,7 @@ def to_dataframe(self, complete: bool = False) -> "pandas.DataFrame":
df.index = np.arange(df.shape[0], dtype="int64")
return df

def to_dict(self, array_as_list: bool = False):
def to_dict(self, array_as_list: bool = False) -> dict:
"""Create a dictionary of all necessary attributes.

Parameters
Expand All @@ -168,7 +184,7 @@ def to_dict(self, array_as_list: bool = False):
return d

@staticmethod
def from_dict(d: dict):
def from_dict(d: dict) -> "ProbeGroup":
"""Instantiate a ProbeGroup from a dictionary

Parameters
Expand Down Expand Up @@ -210,7 +226,7 @@ def get_global_device_channel_indices(self) -> np.ndarray:
channels["device_channel_indices"] = arr["device_channel_indices"]
return channels

def set_global_device_channel_indices(self, channels: np.ndarray | list):
def set_global_device_channel_indices(self, channels: np.ndarray | list) -> None:
"""
Set global indices for all probes

Expand Down Expand Up @@ -249,7 +265,86 @@ def get_global_contact_ids(self) -> np.ndarray:
contact_ids = self.to_numpy(complete=True)["contact_ids"]
return contact_ids

def check_global_device_wiring_and_ids(self):
def get_global_contact_positions(self) -> np.ndarray:
"""
Gets all contact positions concatenated across probes

Returns
-------
contact_positions: np.ndarray
An array of the contact positions across all probes
"""
contact_positions = np.vstack([probe.contact_positions for probe in self.probes])
return contact_positions

def get_slice(self, selection: np.ndarray[bool | int]) -> "ProbeGroup":
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we strongly need to discuss this behavior.
Does the slice is probe by probe concatenated or is the slice is on channel_devcide_index ordered ?
This is very very important when handling several probes!!! sometimes channel indices are interleaved!!
And so the result will not be the same.

"""
Get a copy of the ProbeGroup with a sub selection of contacts.

Selection can be boolean or by index

Parameters
----------
selection : np.array of bool or int (for index)
Either an np.array of bool or for desired selection of contacts
or the indices of the desired contacts

Returns
-------
sliced_probe_group: ProbeGroup
The sliced probe group

"""

n = self.get_contact_count()

selection = np.asarray(selection)
if selection.dtype == "bool":
assert selection.shape == (
n,
), f"if array of bool given it must be the same size as the number of contacts {selection.shape} != {n}"
(selection_indices,) = np.nonzero(selection)
elif selection.dtype.kind == "i":
assert np.unique(selection).size == selection.size
if len(selection) > 0:
assert (
0 <= np.min(selection) < n
), f"An index within your selection is out of bounds {np.min(selection)}"
assert (
0 <= np.max(selection) < n
), f"An index within your selection is out of bounds {np.max(selection)}"
selection_indices = selection
else:
selection_indices = []
else:
raise TypeError(f"selection must be bool array or int array, not of type: {type(selection)}")

if len(selection_indices) == 0:
return ProbeGroup()

# Map selection to indices of individual probes
ind = 0
sliced_probes = []
for probe in self.probes:
n = probe.get_contact_count()
probe_limits = (ind, ind + n)
ind += n

probe_selection_indices = selection_indices[
(selection_indices >= probe_limits[0]) & (selection_indices < probe_limits[1])
]
if len(probe_selection_indices) == 0:
continue
sliced_probe = probe.get_slice(probe_selection_indices - probe_limits[0])
sliced_probes.append(sliced_probe)

sliced_probe_group = ProbeGroup()
for probe in sliced_probes:
sliced_probe_group.add_probe(probe)

return sliced_probe_group

def check_global_device_wiring_and_ids(self) -> None:
# check unique device_channel_indices for !=-1
chans = self.get_global_device_channel_indices()
keep = chans["device_channel_indices"] >= 0
Expand All @@ -258,7 +353,7 @@ def check_global_device_wiring_and_ids(self):
if valid_chans.size != np.unique(valid_chans).size:
raise ValueError("channel device indices are not unique across probes")

def auto_generate_probe_ids(self, *args, **kwargs):
def auto_generate_probe_ids(self, *args, **kwargs) -> None:
"""
Annotate all probes with unique probe_id values.

Expand All @@ -282,7 +377,7 @@ def auto_generate_probe_ids(self, *args, **kwargs):
for pid, probe in enumerate(self.probes):
probe.annotate(probe_id=probe_ids[pid])

def auto_generate_contact_ids(self, *args, **kwargs):
def auto_generate_contact_ids(self, *args, **kwargs) -> None:
"""
Annotate all contacts with unique contact_id values.

Expand Down
168 changes: 153 additions & 15 deletions tests/test_probegroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,27 @@
import numpy as np


def test_probegroup():
@pytest.fixture
def probegroup():
"""Fixture: a ProbeGroup with 3 probes, each with device channel indices set."""
probegroup = ProbeGroup()

nchan = 0
for i in range(3):
probe = generate_dummy_probe()
probe.move([i * 100, i * 80])
n = probe.get_contact_count()
probe.set_device_channel_indices(np.arange(n)[::-1] + nchan)
shank_ids = np.ones(n)
shank_ids[: n // 2] *= i * 2
shank_ids[n // 2 :] *= i * 2 + 1
probe.set_shank_ids(shank_ids)
probe.set_device_channel_indices(np.arange(n) + nchan)
probegroup.add_probe(probe)

nchan += n
return probegroup


def test_probegroup(probegroup):
indices = probegroup.get_global_device_channel_indices()

ids = probegroup.get_global_contact_ids()

df = probegroup.to_dataframe()
# ~ print(df['global_contact_ids'])

arr = probegroup.to_numpy(complete=False)
other = ProbeGroup.from_numpy(arr)
Expand All @@ -38,12 +36,6 @@ def test_probegroup():
d = probegroup.to_dict()
other = ProbeGroup.from_dict(d)

# ~ from probeinterface.plotting import plot_probe_group, plot_probe
# ~ import matplotlib.pyplot as plt
# ~ plot_probe_group(probegroup)
# ~ plot_probe_group(other)
# ~ plt.show()

# checking automatic generation of ids with new dummy probes
probegroup.probes = []
for i in range(3):
Expand Down Expand Up @@ -116,6 +108,152 @@ def test_set_contact_ids_rejects_wrong_size():
probe.set_contact_ids(["a", "b", "c"])


# ── get_global_contact_positions() tests ────────────────────────────────────


def test_get_global_contact_positions_shape(probegroup):
pos = probegroup.get_global_contact_positions()
assert pos.shape == (probegroup.get_contact_count(), probegroup.ndim)


def test_get_global_contact_positions_matches_per_probe(probegroup):
pos = probegroup.get_global_contact_positions()
offset = 0
for probe in probegroup.probes:
n = probe.get_contact_count()
np.testing.assert_array_equal(pos[offset : offset + n], probe.contact_positions)
offset += n


def test_get_global_contact_positions_single_probe(probegroup):
pos = probegroup.get_global_contact_positions()
np.testing.assert_array_equal(
pos[: probegroup.probes[0].get_contact_count()], probegroup.probes[0].contact_positions
)


def test_get_global_contact_positions_3d():
pg = ProbeGroup()
for i in range(2):
probe = generate_dummy_probe().to_3d()
probe.move([i * 100, i * 80, i * 30])
pg.add_probe(probe)
pos = pg.get_global_contact_positions()
assert pos.shape[1] == 3
assert pos.shape[0] == pg.get_contact_count()


def test_get_global_contact_positions_reflects_move():
"""Positions should reflect probe movement."""
pg = ProbeGroup()
probe = generate_dummy_probe()
original_pos = probe.contact_positions.copy()
probe.move([50, 60])
pg.add_probe(probe)
pos = pg.get_global_contact_positions()
np.testing.assert_array_equal(pos, original_pos + np.array([50, 60]))


# ── copy() tests ────────────────────────────────────────────────────────────


def test_copy_returns_new_object(probegroup):
pg_copy = probegroup.copy()
assert pg_copy is not probegroup
assert len(pg_copy.probes) == len(probegroup.probes)
for orig, copied in zip(probegroup.probes, pg_copy.probes):
assert orig is not copied


def test_copy_preserves_positions(probegroup):
pg_copy = probegroup.copy()
for orig, copied in zip(probegroup.probes, pg_copy.probes):
np.testing.assert_array_equal(orig.contact_positions, copied.contact_positions)


def test_copy_preserves_device_channel_indices(probegroup):
pg_copy = probegroup.copy()
np.testing.assert_array_equal(
probegroup.get_global_device_channel_indices(),
pg_copy.get_global_device_channel_indices(),
)


def test_copy_does_not_preserve_contact_ids(probegroup):
"""Probe.copy() intentionally does not copy contact_ids."""
pg_copy = probegroup.copy()
# All contact_ids should be empty strings after copy
assert all(cid == "" for cid in pg_copy.get_global_contact_ids())


def test_copy_is_independent(probegroup):
"""Mutating the copy must not affect the original."""
original_positions = probegroup.probes[0].contact_positions.copy()
pg_copy = probegroup.copy()
pg_copy.probes[0].move([999, 999])
np.testing.assert_array_equal(probegroup.probes[0].contact_positions, original_positions)


# ── get_slice() tests ───────────────────────────────────────────────────────


def test_get_slice_by_bool(probegroup):
total = probegroup.get_contact_count()
sel = np.zeros(total, dtype=bool)
sel[:5] = True # first 5 contacts from the first probe
sliced = probegroup.get_slice(sel)
assert sliced.get_contact_count() == 5


def test_get_slice_by_index(probegroup):
indices = np.array([0, 1, 2, 33, 34]) # contacts from both probes
sliced = probegroup.get_slice(indices)
assert sliced.get_contact_count() == 5


def test_get_slice_preserves_device_channel_indices(probegroup):
indices = np.array([0, 1, 2])
sliced = probegroup.get_slice(indices)
orig_chans = probegroup.get_global_device_channel_indices()["device_channel_indices"][:3]
sliced_chans = sliced.get_global_device_channel_indices()["device_channel_indices"]
np.testing.assert_array_equal(sliced_chans, orig_chans)


def test_get_slice_preserves_positions(probegroup):
indices = np.array([0, 1, 2])
sliced = probegroup.get_slice(indices)
expected = probegroup.get_global_contact_positions()[indices]
np.testing.assert_array_equal(sliced.get_global_contact_positions(), expected)


def test_get_slice_empty_selection(probegroup):
sliced = probegroup.get_slice(np.array([], dtype=int))
assert sliced.get_contact_count() == 0
assert len(sliced.probes) == 0


def test_get_slice_wrong_bool_size(probegroup):
with pytest.raises(AssertionError):
probegroup.get_slice(np.array([True, False])) # wrong size


def test_get_slice_out_of_bounds(probegroup):
total = probegroup.get_contact_count()
with pytest.raises(AssertionError):
probegroup.get_slice(np.array([total + 10]))


def test_get_slice_all_contacts(probegroup):
"""Slicing with all contacts should give an equivalent ProbeGroup."""
total = probegroup.get_contact_count()
sliced = probegroup.get_slice(np.arange(total))
assert sliced.get_contact_count() == total
np.testing.assert_array_equal(
sliced.get_global_contact_positions(),
probegroup.get_global_contact_positions(),
)


if __name__ == "__main__":
test_probegroup()
# ~ test_probegroup_3d()
Loading