diff --git a/.gitignore b/.gitignore index 0ee5de6..6d15af9 100644 --- a/.gitignore +++ b/.gitignore @@ -26,6 +26,7 @@ cov.xml .DS_Store uv.lock +.codex # libraries **/neuropixels_library_generated diff --git a/src/probeinterface/probegroup.py b/src/probeinterface/probegroup.py index d42906a..4271de0 100644 --- a/src/probeinterface/probegroup.py +++ b/src/probeinterface/probegroup.py @@ -78,6 +78,81 @@ def get_contact_count(self) -> int: n = sum(probe.get_contact_count() for probe in self.probes) return n + def _build_contact_vector(self) -> np.ndarray: + """ + Return the channel-ordered dense view of the probegroup, computed fresh. + + Private by convention: this method is intended for integration with SpikeInterface, + which needs a channel-ordered view for recording-facing queries. Fields and dtype + may evolve with consumer requirements, so user code should not depend on it directly. + For stable probegroup state, use the public `get_global_*` methods. + + Invariants + ---------- + - Ordering: rows are sorted ascending by `device_channel_indices` using a stable + sort. Ties preserve per-probe insertion order. + - Row count: one row per *connected* contact (`device_channel_indices >= 0`). + The returned size is generally smaller than `self.get_contact_count()` when the + probegroup has unwired contacts. This matches SpikeInterface's pre-migration + `contact_vector` convention. + - Dtype: includes `probe_index`, `x`, `y`, and `z` if `ndim == 3`. Optional fields + `shank_ids` and `contact_sides` appear only when at least one probe in the group + defines them. Consumers must guard field access accordingly. + - Raises `ValueError` on empty probegroups and on probegroups with no wired + contacts. + + This method builds a fresh array on every call. It is not cached. Consumers that + need to call it repeatedly in a hot loop should cache the result at the call site, + where the lifetime and invalidation story are local. + """ + if len(self.probes) == 0: + raise ValueError("Cannot build a contact_vector for an empty ProbeGroup") + + has_shank_ids = any(probe.shank_ids is not None for probe in self.probes) + has_contact_sides = any(probe.contact_sides is not None for probe in self.probes) + + dtype = [("probe_index", "int64"), ("x", "float64"), ("y", "float64")] + if self.ndim == 3: + dtype.append(("z", "float64")) + if has_shank_ids: + dtype.append(("shank_ids", "U64")) + if has_contact_sides: + dtype.append(("contact_sides", "U8")) + + channel_index_parts = [] + contact_vector_parts = [] + for probe_index, probe in enumerate(self.probes): + device_channel_indices = probe.device_channel_indices + if device_channel_indices is None: + continue + + device_channel_indices = np.asarray(device_channel_indices) + connected = device_channel_indices >= 0 + if not np.any(connected): + continue + + probe_vector = np.zeros(np.sum(connected), dtype=dtype) + probe_vector["probe_index"] = probe_index + probe_vector["x"] = probe.contact_positions[connected, 0] + probe_vector["y"] = probe.contact_positions[connected, 1] + if self.ndim == 3: + probe_vector["z"] = probe.contact_positions[connected, 2] + if has_shank_ids and probe.shank_ids is not None: + probe_vector["shank_ids"] = probe.shank_ids[connected] + if has_contact_sides and probe.contact_sides is not None: + probe_vector["contact_sides"] = probe.contact_sides[connected] + + channel_index_parts.append(device_channel_indices[connected]) + contact_vector_parts.append(probe_vector) + + if len(contact_vector_parts) == 0: + raise ValueError("contact_vector requires at least one wired contact") + + channel_indices = np.concatenate(channel_index_parts, axis=0) + contact_vector = np.concatenate(contact_vector_parts, axis=0) + order = np.argsort(channel_indices, kind="stable") + return contact_vector[order] + def to_numpy(self, complete: bool = False) -> np.ndarray: """ Export all probes into a numpy array. diff --git a/tests/test_probegroup.py b/tests/test_probegroup.py index c942190..6f1ed2c 100644 --- a/tests/test_probegroup.py +++ b/tests/test_probegroup.py @@ -108,6 +108,85 @@ def test_set_contact_ids_rejects_wrong_size(): probe.set_contact_ids(["a", "b", "c"]) +def test_contact_vector_orders_connected_contacts(): + from probeinterface import Probe + + probe0 = Probe(ndim=2, si_units="um") + probe0.set_contacts( + positions=np.array([[10.0, 0.0], [30.0, 0.0]]), + shapes="circle", + shape_params={"radius": 5}, + shank_ids=["s0", "s1"], + contact_sides=["front", "back"], + ) + probe0.set_device_channel_indices([2, -1]) + + probe1 = Probe(ndim=2, si_units="um") + probe1.set_contacts( + positions=np.array([[0.0, 0.0], [20.0, 0.0]]), + shapes="circle", + shape_params={"radius": 5}, + shank_ids=["s0", "s0"], + contact_sides=["front", "front"], + ) + probe1.set_device_channel_indices([0, 1]) + + probegroup = ProbeGroup() + probegroup.add_probe(probe0) + probegroup.add_probe(probe1) + + arr = probegroup._build_contact_vector() + + assert arr.dtype.names == ("probe_index", "x", "y", "shank_ids", "contact_sides") + assert arr.size == 3 + assert np.array_equal(arr["probe_index"], np.array([1, 1, 0])) + assert np.array_equal(arr["x"], np.array([0.0, 20.0, 10.0])) + assert np.array_equal(np.column_stack((arr["x"], arr["y"])), np.array([[0.0, 0.0], [20.0, 0.0], [10.0, 0.0]])) + + +def test_contact_vector_reflects_current_probe_state(): + probegroup = ProbeGroup() + probe = generate_dummy_probe() + probe.set_device_channel_indices(np.arange(probe.get_contact_count())) + probegroup.add_probe(probe) + + dense_before = probegroup._build_contact_vector() + original_positions = np.column_stack((dense_before["x"], dense_before["y"])).copy() + + probe.move([5.0, 0.0]) + + dense_after_move = probegroup._build_contact_vector() + assert dense_after_move is not dense_before + assert np.array_equal( + np.column_stack((dense_after_move["x"], dense_after_move["y"])), + original_positions + np.array([5.0, 0.0]), + ) + + probe.set_shank_ids(np.array(["a"] * probe.get_contact_count())) + dense_with_shanks = probegroup._build_contact_vector() + assert "shank_ids" in dense_with_shanks.dtype.names + + +def test_contact_vector_requires_wired_contacts(): + probegroup = ProbeGroup() + probe = generate_dummy_probe() + probegroup.add_probe(probe) + + with pytest.raises(ValueError, match="requires at least one wired contact"): + probegroup._build_contact_vector() + + +def test_contact_vector_supports_3d_positions(): + probegroup = ProbeGroup() + probe = generate_dummy_probe().to_3d() + probe.set_device_channel_indices(np.arange(probe.get_contact_count())) + probegroup.add_probe(probe) + + dense = probegroup._build_contact_vector() + assert dense.dtype.names[:4] == ("probe_index", "x", "y", "z") + assert np.column_stack((dense["x"], dense["y"], dense["z"])).shape[1] == 3 + + # ── get_global_contact_positions() tests ────────────────────────────────────