From 3dc57290dbde0aeaa5048f2301ee75015a93fe26 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 29 Dec 2025 15:43:44 +0100 Subject: [PATCH 1/6] Test IBL extractors tests failing for PI update --- src/spikeinterface/extractors/tests/test_iblextractors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/extractors/tests/test_iblextractors.py b/src/spikeinterface/extractors/tests/test_iblextractors.py index 972a8e7bb0..56d01e38cf 100644 --- a/src/spikeinterface/extractors/tests/test_iblextractors.py +++ b/src/spikeinterface/extractors/tests/test_iblextractors.py @@ -76,8 +76,8 @@ def test_offsets(self): def test_probe_representation(self): probe = self.recording.get_probe() - expected_probe_representation = "Probe - 384ch - 1shanks" - assert repr(probe) == expected_probe_representation + expected_probe_representation = "Probe - 384ch" + assert expected_probe_representation in repr(probe) def test_property_keys(self): expected_property_keys = [ From e1006bcd580c6941984bbccc56ec00dabbda9ef9 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 16 Apr 2026 11:12:57 +0200 Subject: [PATCH 2/6] Add read_kilosort4_motion function --- .../sorters/external/kilosort4.py | 51 ++++++++++++++++++- src/spikeinterface/sorters/sorterlist.py | 2 +- 2 files changed, 51 insertions(+), 2 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 8dfcc4dda6..68576cb41f 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -1,7 +1,7 @@ import warnings from packaging import version -from spikeinterface.core import write_binary_recording +from spikeinterface.core import write_binary_recording, Motion from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs from .kilosortbase import KilosortBase from spikeinterface.sorters.basesorter import get_job_kwargs @@ -484,3 +484,52 @@ def _setup_json_probe_map(cls, recording, sorter_output_folder): "n_chan": n_chan, } save_probe(probe, str(sorter_output_folder / "chanMap.json")) + + # close logger + for handler in logger.handlers.copy(): + logger.removeHandler(handler) + handler.close() + + +def read_kilosort4_motion(sorter_output_folder: str | Path, recording: BaseRecording | None = None) -> Motion: + """Reads the motion information from a Kilosort4 output folder and returns a Motion object. + + Parameters + ---------- + sorter_output_folder: str or Path + The path to the Kilosort4 output folder. + recording: BaseRecording, optional + The recording object. If provided, the temporal bins will be estimated based on the recording's + start and end times. If not provided, the temporal bins will be estimated based on the number + of batches in the ops file. + + Returns + ------- + Motion + A Motion object containing the displacement, temporal bins, and spatial bins. + + """ + sorter_output_folder = Path(sorter_output_folder) + ops_file = sorter_output_folder / "ops.npy" + if not ops_file.is_file(): + raise FileNotFoundError("'ops.npy' file not found!") + ops = np.load(ops_file, allow_pickle=True).item() + yblk = ops.get("yblk") + dshift = ops.get("dshift") + if yblk is None or dshift is None: + raise Exception("'yblk' and 'dshift' fields not found in ops file!") + displacement = dshift + yblk + spatial_bins_um = yblk + # estimate temporal bins + batch_size = ops["batch_size"] + fs = ops["fs"] + t_bin = batch_size / fs + if recording is not None: + t_start = recording.get_start_time() + t_end = recording.get_end_time() + temporal_bins_s = np.linspace(t_start + t_bin / 2, t_end - t_bin / 2) + else: + temporal_bins_s = np.arange(displacement.shape[0]) * t_bin + t_bin / 2 + + motion = Motion(displacement=displacement, temporal_bins_s=temporal_bins_s, spatial_bins_um=spatial_bins_um) + return motion diff --git a/src/spikeinterface/sorters/sorterlist.py b/src/spikeinterface/sorters/sorterlist.py index 140e0185e2..3c6289ba6d 100644 --- a/src/spikeinterface/sorters/sorterlist.py +++ b/src/spikeinterface/sorters/sorterlist.py @@ -6,7 +6,7 @@ from .external.kilosort2 import Kilosort2Sorter from .external.kilosort2_5 import Kilosort2_5Sorter from .external.kilosort3 import Kilosort3Sorter -from .external.kilosort4 import Kilosort4Sorter +from .external.kilosort4 import Kilosort4Sorter, read_kilosort4_motion from .external.pykilosort import PyKilosortSorter from .external.klusta import KlustaSorter from .external.mountainsort4 import Mountainsort4Sorter From 6d904db3150fdafce9f356394ce64dfa2c5588e0 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 16 Apr 2026 11:31:32 +0200 Subject: [PATCH 3/6] Fix imports --- src/spikeinterface/sorters/external/kilosort4.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 68576cb41f..b6bd82e370 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -1,6 +1,8 @@ import warnings +from pathlib import Path from packaging import version + from spikeinterface.core import write_binary_recording, Motion from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs from .kilosortbase import KilosortBase From bb0456a0b2f0612025432a3731ff7a238afe2ca3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 16 Apr 2026 17:03:33 +0200 Subject: [PATCH 4/6] fix imports 2 --- src/spikeinterface/sorters/external/kilosort4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index b6bd82e370..0731b0ac8f 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -3,7 +3,7 @@ from packaging import version -from spikeinterface.core import write_binary_recording, Motion +from spikeinterface.core import write_binary_recording, Motion, BaseRecording from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs from .kilosortbase import KilosortBase from spikeinterface.sorters.basesorter import get_job_kwargs From b26ab8f4d90044d78b15ed8b5b2c73a384c2935c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 16 Apr 2026 17:26:49 +0200 Subject: [PATCH 5/6] fix logger closing --- src/spikeinterface/sorters/external/kilosort4.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index 0731b0ac8f..c6b49db3b5 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -455,6 +455,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if (sorter_output_folder / "recording.dat").is_file(): (sorter_output_folder / "recording.dat").unlink() + # close logger + for handler in logger.handlers.copy(): + logger.removeHandler(handler) + handler.close() + @classmethod def _get_result_from_folder(cls, sorter_output_folder): return KilosortBase._get_result_from_folder(sorter_output_folder) @@ -487,11 +492,6 @@ def _setup_json_probe_map(cls, recording, sorter_output_folder): } save_probe(probe, str(sorter_output_folder / "chanMap.json")) - # close logger - for handler in logger.handlers.copy(): - logger.removeHandler(handler) - handler.close() - def read_kilosort4_motion(sorter_output_folder: str | Path, recording: BaseRecording | None = None) -> Motion: """Reads the motion information from a Kilosort4 output folder and returns a Motion object. From 0ea163f2070800320363cfe5e5fac95fa99ed02c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 17 Apr 2026 09:44:50 +0200 Subject: [PATCH 6/6] fix: read function displacement, add test for read_kilosort4_motion function --- .github/scripts/test_kilosort4_ci.py | 40 ++++++++++++++++++- src/spikeinterface/core/motion.py | 4 +- .../sorters/external/kilosort4.py | 7 ++-- 3 files changed, 45 insertions(+), 6 deletions(-) diff --git a/.github/scripts/test_kilosort4_ci.py b/.github/scripts/test_kilosort4_ci.py index b075704628..75bac8f03b 100644 --- a/.github/scripts/test_kilosort4_ci.py +++ b/.github/scripts/test_kilosort4_ci.py @@ -28,7 +28,8 @@ import spikeinterface.full as si from spikeinterface.core.testing import check_sortings_equal -from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter +from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter, read_kilosort4_motion +from spikeinterface.core.motion import Motion from probeinterface.io import write_prb from spikeinterface.extractors import read_kilosort_as_analyzer @@ -669,6 +670,43 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None): assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1) assert np.array_equal(results["ks"]["clus"], results["si"]["clus"]) + def test_read_kilosort4_motion(self, recording_and_paths, tmp_path): + """ + Test that read_kilosort4_motion returns a Motion object whose displacement + equals dshift (not dshift + yblk), and that temporal/spatial bins are correct. + """ + recording, _ = recording_and_paths + sorter_output_dir = tmp_path / "ks4_motion_output" / "sorter_output" + + si.run_sorter( + "kilosort4", + recording, + folder=tmp_path / "ks4_motion_output", + remove_existing_folder=True, + ) + + ops = np.load(sorter_output_dir / "ops.npy", allow_pickle=True).item() + yblk = ops["yblk"] + dshift = ops["dshift"] + + # without recording: temporal bins estimated from batch count + motion = read_kilosort4_motion(sorter_output_dir) + assert isinstance(motion, Motion) + assert motion.displacement[0].shape == dshift.shape + np.testing.assert_array_equal(motion.displacement[0], dshift) + np.testing.assert_array_equal(motion.spatial_bins_um, yblk) + assert motion.temporal_bins_s[0].shape[0] == dshift.shape[0] + # displacement must be relative (not offset by spatial bin position) + assert not np.allclose(motion.displacement[0], dshift + yblk) + + # with recording: temporal bins bounded by recording times + motion_rec = read_kilosort4_motion(sorter_output_dir, recording=recording) + assert isinstance(motion_rec, Motion) + np.testing.assert_array_equal(motion_rec.displacement[0], dshift) + assert motion_rec.temporal_bins_s[0].shape[0] == dshift.shape[0] + assert motion_rec.temporal_bins_s[0][0] >= recording.get_start_time() + assert motion_rec.temporal_bins_s[0][-1] <= recording.get_end_time() + ##### Helpers ###### def _get_kilosort_native_settings(self, recording, paths, param_key, param_value): """ diff --git a/src/spikeinterface/core/motion.py b/src/spikeinterface/core/motion.py index de57b4f9a4..6a99ca29c8 100644 --- a/src/spikeinterface/core/motion.py +++ b/src/spikeinterface/core/motion.py @@ -14,7 +14,8 @@ class Motion: Parameters ---------- displacement : numpy array 2d or list of - Motion estimate in um. + Motion estimate in um, relative to the spatial_bins_um. + The first dimension is temporal bins, the second dimension is spatial bins. List is the number of segment. For each semgent : @@ -93,6 +94,7 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_inde Parameters ---------- times_s: np.array + Times at which to evaluate the motion, in seconds. This should be a one-dimensional array. locations_um: np.array Either this is a one-dimensional array (a vector of positions along self.dimension), or else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1. diff --git a/src/spikeinterface/sorters/external/kilosort4.py b/src/spikeinterface/sorters/external/kilosort4.py index c6b49db3b5..016af2d392 100644 --- a/src/spikeinterface/sorters/external/kilosort4.py +++ b/src/spikeinterface/sorters/external/kilosort4.py @@ -2,6 +2,7 @@ from pathlib import Path from packaging import version +import numpy as np from spikeinterface.core import write_binary_recording, Motion, BaseRecording from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs @@ -171,7 +172,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): import time import torch - import numpy as np import logging if version.parse(cls.get_sorter_version()) < version.parse("4.0.16"): @@ -468,7 +468,6 @@ def _get_result_from_folder(cls, sorter_output_folder): def _setup_json_probe_map(cls, recording, sorter_output_folder): """Create a JSON probe map file for Kilosort4.""" from kilosort.io import save_probe - import numpy as np groups = recording.get_channel_groups() positions = np.array(recording.get_channel_locations()) @@ -520,7 +519,7 @@ def read_kilosort4_motion(sorter_output_folder: str | Path, recording: BaseRecor dshift = ops.get("dshift") if yblk is None or dshift is None: raise Exception("'yblk' and 'dshift' fields not found in ops file!") - displacement = dshift + yblk + displacement = dshift spatial_bins_um = yblk # estimate temporal bins batch_size = ops["batch_size"] @@ -529,7 +528,7 @@ def read_kilosort4_motion(sorter_output_folder: str | Path, recording: BaseRecor if recording is not None: t_start = recording.get_start_time() t_end = recording.get_end_time() - temporal_bins_s = np.linspace(t_start + t_bin / 2, t_end - t_bin / 2) + temporal_bins_s = np.linspace(t_start + t_bin / 2, t_end - t_bin / 2, displacement.shape[0]) else: temporal_bins_s = np.arange(displacement.shape[0]) * t_bin + t_bin / 2