Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 39 additions & 1 deletion .github/scripts/test_kilosort4_ci.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@

import spikeinterface.full as si
from spikeinterface.core.testing import check_sortings_equal
from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter
from spikeinterface.sorters.external.kilosort4 import Kilosort4Sorter, read_kilosort4_motion
from spikeinterface.core.motion import Motion
from probeinterface.io import write_prb
from spikeinterface.extractors import read_kilosort_as_analyzer

Expand Down Expand Up @@ -669,6 +670,43 @@ def monkeypatch_filter_function(self, X, ops=None, ibatch=None):
assert np.allclose(results["ks"]["st"], results["si"]["st"], rtol=0, atol=1)
assert np.array_equal(results["ks"]["clus"], results["si"]["clus"])

def test_read_kilosort4_motion(self, recording_and_paths, tmp_path):
"""
Test that read_kilosort4_motion returns a Motion object whose displacement
equals dshift (not dshift + yblk), and that temporal/spatial bins are correct.
"""
recording, _ = recording_and_paths
sorter_output_dir = tmp_path / "ks4_motion_output" / "sorter_output"

si.run_sorter(
"kilosort4",
recording,
folder=tmp_path / "ks4_motion_output",
remove_existing_folder=True,
)

ops = np.load(sorter_output_dir / "ops.npy", allow_pickle=True).item()
yblk = ops["yblk"]
dshift = ops["dshift"]

# without recording: temporal bins estimated from batch count
motion = read_kilosort4_motion(sorter_output_dir)
assert isinstance(motion, Motion)
assert motion.displacement[0].shape == dshift.shape
np.testing.assert_array_equal(motion.displacement[0], dshift)
np.testing.assert_array_equal(motion.spatial_bins_um, yblk)
assert motion.temporal_bins_s[0].shape[0] == dshift.shape[0]
# displacement must be relative (not offset by spatial bin position)
assert not np.allclose(motion.displacement[0], dshift + yblk)

# with recording: temporal bins bounded by recording times
motion_rec = read_kilosort4_motion(sorter_output_dir, recording=recording)
assert isinstance(motion_rec, Motion)
np.testing.assert_array_equal(motion_rec.displacement[0], dshift)
assert motion_rec.temporal_bins_s[0].shape[0] == dshift.shape[0]
assert motion_rec.temporal_bins_s[0][0] >= recording.get_start_time()
assert motion_rec.temporal_bins_s[0][-1] <= recording.get_end_time()

##### Helpers ######
def _get_kilosort_native_settings(self, recording, paths, param_key, param_value):
"""
Expand Down
4 changes: 3 additions & 1 deletion src/spikeinterface/core/motion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ class Motion:
Parameters
----------
displacement : numpy array 2d or list of
Motion estimate in um.
Motion estimate in um, relative to the spatial_bins_um.
The first dimension is temporal bins, the second dimension is spatial bins.
List is the number of segment.
For each semgent :

Expand Down Expand Up @@ -93,6 +94,7 @@ def get_displacement_at_time_and_depth(self, times_s, locations_um, segment_inde
Parameters
----------
times_s: np.array
Times at which to evaluate the motion, in seconds. This should be a one-dimensional array.
locations_um: np.array
Either this is a one-dimensional array (a vector of positions along self.dimension), or
else a 2d array with the 2 or 3 spatial dimensions indexed along axis=1.
Expand Down
56 changes: 53 additions & 3 deletions src/spikeinterface/sorters/external/kilosort4.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import warnings
from pathlib import Path
from packaging import version

from spikeinterface.core import write_binary_recording
import numpy as np

from spikeinterface.core import write_binary_recording, Motion, BaseRecording
from spikeinterface.sorters.basesorter import BaseSorter, get_job_kwargs
from .kilosortbase import KilosortBase
from spikeinterface.sorters.basesorter import get_job_kwargs
Expand Down Expand Up @@ -169,7 +172,6 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

import time
import torch
import numpy as np
import logging

if version.parse(cls.get_sorter_version()) < version.parse("4.0.16"):
Expand Down Expand Up @@ -453,6 +455,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if (sorter_output_folder / "recording.dat").is_file():
(sorter_output_folder / "recording.dat").unlink()

# close logger
for handler in logger.handlers.copy():
logger.removeHandler(handler)
handler.close()

@classmethod
def _get_result_from_folder(cls, sorter_output_folder):
return KilosortBase._get_result_from_folder(sorter_output_folder)
Expand All @@ -461,7 +468,6 @@ def _get_result_from_folder(cls, sorter_output_folder):
def _setup_json_probe_map(cls, recording, sorter_output_folder):
"""Create a JSON probe map file for Kilosort4."""
from kilosort.io import save_probe
import numpy as np

groups = recording.get_channel_groups()
positions = np.array(recording.get_channel_locations())
Expand All @@ -484,3 +490,47 @@ def _setup_json_probe_map(cls, recording, sorter_output_folder):
"n_chan": n_chan,
}
save_probe(probe, str(sorter_output_folder / "chanMap.json"))


def read_kilosort4_motion(sorter_output_folder: str | Path, recording: BaseRecording | None = None) -> Motion:
"""Reads the motion information from a Kilosort4 output folder and returns a Motion object.

Parameters
----------
sorter_output_folder: str or Path
The path to the Kilosort4 output folder.
recording: BaseRecording, optional
The recording object. If provided, the temporal bins will be estimated based on the recording's
start and end times. If not provided, the temporal bins will be estimated based on the number
of batches in the ops file.

Returns
-------
Motion
A Motion object containing the displacement, temporal bins, and spatial bins.

"""
sorter_output_folder = Path(sorter_output_folder)
ops_file = sorter_output_folder / "ops.npy"
if not ops_file.is_file():
raise FileNotFoundError("'ops.npy' file not found!")
ops = np.load(ops_file, allow_pickle=True).item()
yblk = ops.get("yblk")
dshift = ops.get("dshift")
if yblk is None or dshift is None:
raise Exception("'yblk' and 'dshift' fields not found in ops file!")
displacement = dshift
spatial_bins_um = yblk
# estimate temporal bins
batch_size = ops["batch_size"]
fs = ops["fs"]
t_bin = batch_size / fs
if recording is not None:
t_start = recording.get_start_time()
t_end = recording.get_end_time()
temporal_bins_s = np.linspace(t_start + t_bin / 2, t_end - t_bin / 2, displacement.shape[0])
else:
temporal_bins_s = np.arange(displacement.shape[0]) * t_bin + t_bin / 2

motion = Motion(displacement=displacement, temporal_bins_s=temporal_bins_s, spatial_bins_um=spatial_bins_um)
return motion
2 changes: 1 addition & 1 deletion src/spikeinterface/sorters/sorterlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .external.kilosort2 import Kilosort2Sorter
from .external.kilosort2_5 import Kilosort2_5Sorter
from .external.kilosort3 import Kilosort3Sorter
from .external.kilosort4 import Kilosort4Sorter
from .external.kilosort4 import Kilosort4Sorter, read_kilosort4_motion
from .external.pykilosort import PyKilosortSorter
from .external.klusta import KlustaSorter
from .external.mountainsort4 import Mountainsort4Sorter
Expand Down
Loading